A brain tumor is one of the most aggressive diseases affecting both children and adults. Automated classification using Machine Learning (ML) and Artificial Intelligence (AI) has consistently shown higher accuracy than manual classification. Hence, a system that performs detection and classification with Deep Learning algorithms such as Convolutional Neural Networks (CNN), Artificial Neural Networks (ANN), and Transfer Learning (TL) would be helpful to doctors all around the world.
The goal here is to identify the tumor type among 'glioma_tumor', 'no_tumor', 'meningioma_tumor', 'pituitary_tumor'.
The aim is to detect and classify brain tumors using CNN and TL, two assets of Deep Learning, and to examine the tumor position (segmentation).
A TensorFlow CNN-based brain tumor detector will be used.
import os
import numpy as np
import pandas as pd
import matplotlib.colors
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.subplots as sp
import plotly.express as px
from skimage.transform import resize
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from sklearn.metrics import (accuracy_score, classification_report,
                             confusion_matrix, ConfusionMatrixDisplay)
### Creating the CNN Model
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, regularizers
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.layers import (Input, InputLayer, Conv2D, MaxPooling2D,
                                     Flatten, Dense, Dropout, BatchNormalization)
from tensorflow.keras.utils import to_categorical, plot_model
# Training Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
colors_dark = ["#1F1F1F", "#313131", '#636363', '#AEAEAE', '#DADADA']
colors_red = ["#331313", "#582626", '#9E1717', '#D35151', '#E9B4B4']
colors_green = ['#01411C','#4B6F44','#4F7942','#74C365','#D0F0C0']
base_dir = 'C:\\Users\\yanch\\Desktop\\UC\\Classes\\2024 Spring\\ADSP 31009 Machine Learning and Predictive Analytics\\Final Project'
train_dir = os.path.join(base_dir, 'Training')
test_dir = os.path.join(base_dir, 'Testing')
labels = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']
X_train = [] #Training Dataset
Y_train = [] #Training Labels
image_size=224
# Load every image from both the Training and Testing folders into one pool;
# the data is re-split into train/validation/test sets further below.
for label in labels:
    path = os.path.join(train_dir, label)
    class_num = labels.index(label)
    for img in os.listdir(path):
        img_array = plt.imread(os.path.join(path, img))
        # skimage's resize returns a float image rescaled to [0, 1]
        img_resized = resize(img_array, (image_size, image_size, 3))
        X_train.append(img_resized)
        Y_train.append(class_num)
for label in labels:
    path = os.path.join(test_dir, label)
    class_num = labels.index(label)
    for img in os.listdir(path):
        img_array = plt.imread(os.path.join(path, img))
        img_resized = resize(img_array, (image_size, image_size, 3))
        X_train.append(img_resized)
        Y_train.append(class_num)
X_train = np.array(X_train)
Y_train = np.array(Y_train)
# Data generators (defined here for directory-based augmentation; note that
# the model fits below train on the in-memory X_train/X_valid arrays instead)
train_datagen = ImageDataGenerator(rescale=1/255,
                                   rotation_range=90,
                                   shear_range=0.2,
                                   zoom_range=0.2,
                                   horizontal_flip=True,
                                   vertical_flip=True,
                                   validation_split=0.2)
valid_datagen = ImageDataGenerator(rescale=1/255, validation_split=0.2)
train_generator = train_datagen.flow_from_directory(train_dir,
                                                    target_size=(224, 224),
                                                    color_mode='rgb',
                                                    shuffle=True,
                                                    subset='training',
                                                    batch_size=32,
                                                    class_mode='categorical')
val_generator = valid_datagen.flow_from_directory(train_dir,
                                                  target_size=(224, 224),
                                                  color_mode='rgb',
                                                  shuffle=True,
                                                  subset='validation',
                                                  batch_size=32,
                                                  class_mode='categorical')
Found 2297 images belonging to 4 classes. Found 573 images belonging to 4 classes.
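To sanity-check what the augmentation pipeline produces, here is a minimal sketch (not part of the original run) that pulls one batch from train_generator and plots the first few augmented images:
# Illustrative only: inspect one augmented batch from the generator
images_batch, labels_batch = next(train_generator)
plt.figure(figsize=(12, 3))
for i in range(4):
    plt.subplot(1, 4, i + 1)
    plt.imshow(images_batch[i])  # already rescaled to [0, 1] by the generator
    plt.axis('off')
plt.show()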
X_train.shape
(3264, 224, 224, 3)
# Shuffling data
X_train, Y_train = shuffle(X_train, Y_train, random_state=42)
# After shuffling, the sample size remains the same
X_train.shape
(3264, 224, 224, 3)
# This method uses the classes array, which directly indicates the class index for each image
(unique, counts) = np.unique(train_generator.classes, return_counts=True)
class_counts = dict(zip(unique, counts))
# Mapping index to class names
class_names = {v: k for k, v in train_generator.class_indices.items()}
class_counts_named = {class_names[k]: v for k, v in class_counts.items()}
# Plotting
plt.figure(figsize=(10, 5))
plt.bar(class_counts_named.keys(), class_counts_named.values())
plt.title('Distribution of Classes in Training Data')
plt.xlabel('Class')
plt.ylabel('Number of Images')
plt.xticks(rotation=45)
plt.show()
Analyzing the distribution of pixel intensities can help in understanding the general characteristics of the images, like contrast and brightness, and might suggest necessary preprocessing steps like histogram equalization.
fig, ax = plt.subplots()
for i, img in enumerate(X_train[:5]):  # use the same images from the first batch
    sns.histplot(img.ravel(), label=f'Image {i}', ax=ax, kde=True)
ax.set_title('Pixel Intensity Distribution')
ax.legend()
plt.show()
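If the histograms were to show poor contrast, histogram equalization is one option; a minimal sketch using skimage.exposure (illustrative only, not part of the pipeline above):
from skimage import exposure
img = X_train[0]                       # one resized RGB scan, values in [0, 1]
img_eq = exposure.equalize_hist(img)   # global histogram equalization
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(img)
axes[0].set_title('Original')
axes[0].axis('off')
axes[1].imshow(img_eq)
axes[1].set_title('Equalized')
axes[1].axis('off')
plt.show()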
# Plotting the images
plt.figure(figsize=(20, 20))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow(X_train[i])
    plt.title(labels[Y_train[i]], fontsize=16, fontweight='bold')
    plt.axis("off")
plt.show()
# Split the data into training, validation, and test sets
X_train, X_test, Y_train, Y_test = train_test_split(X_train, Y_train, test_size=0.2, random_state=42)
X_train, X_valid, Y_train, Y_valid = train_test_split(X_train, Y_train, test_size=0.1, random_state=42)
print(X_train.shape)
print(X_valid.shape)
print(X_test.shape)
print(Y_train.shape)
print(Y_test.shape)
print(Y_valid.shape)
(2349, 224, 224, 3)
(262, 224, 224, 3)
(653, 224, 224, 3)
(2349,)
(653,)
(262,)
# Count the number of images in each class
class_counts = np.bincount(Y_train)
class_names = ['glioma', 'meningioma', 'no tumor', 'pituitary']
# Create a DataFrame with class names and counts
train_df = pd.DataFrame({'Class': class_names, 'Count': class_counts})
# Create a bar chart using matplotlib
fig, ax = plt.subplots()
# Plot the bar chart
ax.barh(train_df['Class'], train_df['Count'])
# Add title and labels
ax.set_title('Number of Images in Each Class of the Train Data')
ax.set_xlabel('Count')
ax.set_ylabel('Class')
# Display the plot
plt.show()
# Convert integer labels to one-hot (categorical) vectors
y_train_new = to_categorical(Y_train, num_classes=4)
y_valid_new = to_categorical(Y_valid, num_classes=4)
y_test_new = to_categorical(Y_test, num_classes=4)
y_train_new.shape
(2349, 4)
y_test_new.shape
(653, 4)
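As a quick sanity check on the encoding, class index 2 ('no_tumor' in the labels list) maps to the one-hot vector [0., 0., 1., 0.]:
print(to_categorical(2, num_classes=4))  # [0. 0. 1. 0.]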
Training Loss (Blue Line):
The training loss starts high and decreases sharply, flattening out around epoch 10, which indicates that the model learns effectively from the training data and minimizes the loss.
Validation Loss (Orange Line):
The validation loss also starts high, drops rapidly at first, and then fluctuates before stabilizing after epoch 10. It remains consistently low, indicating good generalization to the validation data despite some fluctuations.
The model shows strong overall performance, with an accuracy of 0.92 and consistently high precision, recall, and f1-scores across all classes.
Both training and validation losses decrease steadily and stabilize, indicating that the model is learning effectively and generalizing well to the validation set.
The fluctuations in validation loss suggest some variability, but the overall trend remains low.
# Model architecture without regularization
model = Sequential()
model.add(InputLayer(shape=(image_size, image_size, 3)))
model.add(Conv2D(16, kernel_size=(5, 5), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(256, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dense(4, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# Summary of the model
model.summary()
Model: "sequential_2"
Layer (type)                     Output Shape            Param #
conv2d_10 (Conv2D)               (None, 220, 220, 16)    1,216
max_pooling2d_10 (MaxPooling2D)  (None, 110, 110, 16)    0
conv2d_11 (Conv2D)               (None, 108, 108, 32)    4,640
max_pooling2d_11 (MaxPooling2D)  (None, 54, 54, 32)      0
conv2d_12 (Conv2D)               (None, 52, 52, 64)      18,496
max_pooling2d_12 (MaxPooling2D)  (None, 26, 26, 64)      0
conv2d_13 (Conv2D)               (None, 24, 24, 128)     73,856
max_pooling2d_13 (MaxPooling2D)  (None, 12, 12, 128)     0
conv2d_14 (Conv2D)               (None, 10, 10, 256)     295,168
max_pooling2d_14 (MaxPooling2D)  (None, 5, 5, 256)       0
flatten_2 (Flatten)              (None, 6400)            0
dense_9 (Dense)                  (None, 512)             3,277,312
dense_10 (Dense)                 (None, 4)               2,052
Total params: 3,672,740 (14.01 MB)
Trainable params: 3,672,740 (14.01 MB)
Non-trainable params: 0 (0.00 B)
# Note: with batch_size=64 and 2349 training samples there are only ~37 batches
# per epoch, so steps_per_epoch=100 exhausts the data and triggers the
# "ran out of data" warning below.
history = model.fit(X_train, y_train_new,
                    batch_size=64,
                    epochs=35,
                    steps_per_epoch=100,
                    validation_data=(X_valid, y_valid_new))
Epoch 1/35 37/100 ━━━━━━━━━━━━━━━━━━━━ 48s 773ms/step - accuracy: 0.3965 - loss: 1.2695
C:\Users\yanch\anaconda3\Lib\contextlib.py:155: UserWarning: Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset.
100/100 ━━━━━━━━━━━━━━━━━━━━ 35s 299ms/step - accuracy: 0.4503 - loss: 1.1917 - val_accuracy: 0.5992 - val_loss: 0.9012
Epoch 2/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 24s 238ms/step - accuracy: 0.6621 - loss: 0.8256 - val_accuracy: 0.7099 - val_loss: 0.6280
Epoch 3/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 25s 243ms/step - accuracy: 0.7634 - loss: 0.6076 - val_accuracy: 0.6908 - val_loss: 0.7620
Epoch 4/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 39s 218ms/step - accuracy: 0.7917 - loss: 0.5103 - val_accuracy: 0.8168 - val_loss: 0.4132
Epoch 5/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 214ms/step - accuracy: 0.8925 - loss: 0.2956 - val_accuracy: 0.7824 - val_loss: 0.5028
Epoch 6/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 23s 222ms/step - accuracy: 0.9123 - loss: 0.2329 - val_accuracy: 0.8092 - val_loss: 0.4947
Epoch 7/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 217ms/step - accuracy: 0.9361 - loss: 0.1684 - val_accuracy: 0.8893 - val_loss: 0.4255
Epoch 8/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 219ms/step - accuracy: 0.9571 - loss: 0.1243 - val_accuracy: 0.8626 - val_loss: 0.4455
Epoch 9/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 221ms/step - accuracy: 0.9795 - loss: 0.0562 - val_accuracy: 0.8931 - val_loss: 0.4505
Epoch 10/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 214ms/step - accuracy: 0.9851 - loss: 0.0488 - val_accuracy: 0.8740 - val_loss: 0.5628
Epoch 11/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 214ms/step - accuracy: 0.9678 - loss: 0.0820 - val_accuracy: 0.8969 - val_loss: 0.4931
Epoch 12/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 214ms/step - accuracy: 0.9799 - loss: 0.0519 - val_accuracy: 0.8817 - val_loss: 0.5548
Epoch 13/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 214ms/step - accuracy: 0.9846 - loss: 0.0475 - val_accuracy: 0.8931 - val_loss: 0.6534
Epoch 14/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 23s 226ms/step - accuracy: 0.9886 - loss: 0.0339 - val_accuracy: 0.8855 - val_loss: 0.6284
Epoch 15/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 213ms/step - accuracy: 0.9918 - loss: 0.0278 - val_accuracy: 0.9084 - val_loss: 0.5133
Epoch 16/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 212ms/step - accuracy: 0.9927 - loss: 0.0279 - val_accuracy: 0.9084 - val_loss: 0.5029
Epoch 17/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 24s 234ms/step - accuracy: 0.9956 - loss: 0.0128 - val_accuracy: 0.9275 - val_loss: 0.5578
Epoch 18/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 214ms/step - accuracy: 0.9986 - loss: 0.0081 - val_accuracy: 0.8969 - val_loss: 0.7537
Epoch 19/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 23s 225ms/step - accuracy: 0.9966 - loss: 0.0129 - val_accuracy: 0.9008 - val_loss: 0.5796
Epoch 20/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 216ms/step - accuracy: 0.9959 - loss: 0.0110 - val_accuracy: 0.9046 - val_loss: 0.6764
Epoch 21/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 24s 236ms/step - accuracy: 0.9989 - loss: 0.0064 - val_accuracy: 0.9084 - val_loss: 0.6961
Epoch 22/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 216ms/step - accuracy: 0.9993 - loss: 0.0053 - val_accuracy: 0.9160 - val_loss: 0.6376
Epoch 23/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 219ms/step - accuracy: 0.9992 - loss: 0.0044 - val_accuracy: 0.9160 - val_loss: 0.5882
Epoch 24/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 23s 229ms/step - accuracy: 0.9994 - loss: 0.0037 - val_accuracy: 0.9198 - val_loss: 0.5992
Epoch 25/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 214ms/step - accuracy: 0.9987 - loss: 0.0043 - val_accuracy: 0.9198 - val_loss: 0.5866
Epoch 26/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 23s 223ms/step - accuracy: 0.9978 - loss: 0.0045 - val_accuracy: 0.9237 - val_loss: 0.5753
Epoch 27/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 216ms/step - accuracy: 0.9991 - loss: 0.0075 - val_accuracy: 0.9198 - val_loss: 0.5439
Epoch 28/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 221ms/step - accuracy: 0.9987 - loss: 0.0040 - val_accuracy: 0.9160 - val_loss: 0.5814
Epoch 29/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 23s 222ms/step - accuracy: 0.9981 - loss: 0.0045 - val_accuracy: 0.9198 - val_loss: 0.5544
Epoch 30/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 214ms/step - accuracy: 0.9993 - loss: 0.0044 - val_accuracy: 0.9160 - val_loss: 0.5499
Epoch 31/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 219ms/step - accuracy: 0.9992 - loss: 0.0028 - val_accuracy: 0.9160 - val_loss: 0.5503
Epoch 32/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 213ms/step - accuracy: 0.9992 - loss: 0.0031 - val_accuracy: 0.9160 - val_loss: 0.5302
Epoch 33/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 214ms/step - accuracy: 0.9989 - loss: 0.0029 - val_accuracy: 0.9160 - val_loss: 0.5438
Epoch 34/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 23s 225ms/step - accuracy: 0.9989 - loss: 0.0026 - val_accuracy: 0.9160 - val_loss: 0.5614
Epoch 35/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 22s 213ms/step - accuracy: 0.9987 - loss: 0.0036 - val_accuracy: 0.9160 - val_loss: 0.5497
# Save the baseline CNN model (trained without regularization)
model.save('cnn_model_1.keras')
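The saved model can later be restored with the load_model utility imported above (a usage sketch; not executed in this notebook):
restored = load_model('cnn_model_1.keras')  # returns the model, ready for predict/evaluate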
# Plotting the training and validation loss
plt.figure(figsize=(10, 5))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('training_validation_loss.png')
plt.show()
# Predict on the validation set
y_pred = model.predict(X_valid)
y_pred = np.argmax(y_pred, axis=1)
# Calculate accuracy
accuracy = accuracy_score(Y_valid, y_pred)
print('Val Accuracy = %.4f' % accuracy)
9/9 ━━━━━━━━━━━━━━━━━━━━ 1s 80ms/step
Val Accuracy = 0.9160
# Predict on the test set
y_pred = model.predict(X_test)
y_pred = np.argmax(y_pred, axis=1)
# Calculate accuracy
accuracy = accuracy_score(Y_test, y_pred)
print('Test Accuracy = %.4f' % accuracy)
21/21 ━━━━━━━━━━━━━━━━━━━━ 2s 102ms/step
Test Accuracy = 0.9250
print("Classification Report:\n",classification_report(Y_test, y_pred))
Classification Report:
precision recall f1-score support
0 0.91 0.90 0.91 198
1 0.89 0.91 0.90 183
2 0.95 0.88 0.92 104
3 0.96 0.99 0.98 168
accuracy 0.92 653
macro avg 0.93 0.92 0.93 653
weighted avg 0.93 0.92 0.92 653
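confusion_matrix and ConfusionMatrixDisplay were imported above but not used; a minimal sketch to visualize the per-class errors behind this report (using the test predictions y_pred from the cell above):
cm = confusion_matrix(Y_test, y_pred)
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels).plot(xticks_rotation=45)
plt.show()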
The next model (built below) accepts input images of size (image_size, image_size, 3), which corresponds to image height, image width, and 3 color channels (RGB).
Conv2D with 16 filters: Applies a 5x5 convolution kernel to extract features such as edges and textures. The use of 16 filters means it will output 16 different feature maps.
BatchNormalization: Normalizes the activations from the previous layer, which accelerates training and stabilizes learning by re-centering and re-scaling the activations.
MaxPooling2D: Reduces the spatial dimensions (height and width) of the input volume to the next layer by taking the maximum value over a 2x2 pooling window. This helps in reducing the computational cost and overfitting by providing an abstracted form of the representation.
Dropout (0.2): Randomly sets the outgoing edges of 20% of the neurons to zero during training, to prevent overfitting.
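As a tiny illustration of the 2x2 max pooling described above (illustrative only, not part of the notebook's pipeline), a 4x4 feature map reduces to 2x2 by keeping the maximum of each non-overlapping 2x2 window:
x = tf.reshape(tf.constant([[ 1.,  2.,  3.,  4.],
                            [ 5.,  6.,  7.,  8.],
                            [ 9., 10., 11., 12.],
                            [13., 14., 15., 16.]]), (1, 4, 4, 1))  # (batch, h, w, channels)
pooled = MaxPooling2D(pool_size=(2, 2))(x)
print(tf.reshape(pooled, (2, 2)).numpy())  # [[ 6.  8.] [14. 16.]]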
The subsequent blocks increase the number of filters (32, 64, 128, 256); more filters allow the network to capture more complex patterns like textures and shapes.
Each block follows a similar structure: a convolution layer, batch normalization, max pooling, and dropout. This repeated structure helps the network in learning hierarchically more complex features at each level.
Kernel sizes are generally smaller (3x3) in subsequent layers, which is common as deeper layers capture higher-level abstract features where finer granularity is less important.
The output of the final convolutional layer is flattened (converted from a stack of feature maps to a vector, here 5 x 5 x 256 = 6,400 values) so it can be fed into the dense layers.
Dense Layer with 512 neurons: This layer is fully connected and uses ReLU activation. It serves as a classifier on the features formed by the convolutions and pooling layers. Dropout (0.2): Again used here to reduce overfitting.
Dense layer with 4 neurons: This implies the model is intended for a classification task with 4 classes. The softmax activation function is used to output a probability distribution over the 4 classes.
The model uses the Adam optimizer, a popular choice for deep learning tasks as it combines the best properties of the AdaGrad and RMSProp algorithms to optimize its weights.
The loss function is categorical_crossentropy, suitable for multi-class classification problems.
The metric used to evaluate the model is accuracy.
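To make the loss concrete, here is a hand-computed categorical cross-entropy for a single example (a minimal sketch; the probability values are made up): with one-hot label [0, 0, 1, 0], the loss is the negative log of the probability assigned to the true class.
y_true = np.array([0., 0., 1., 0.])
y_prob = np.array([0.05, 0.05, 0.80, 0.10])   # hypothetical softmax output
print(-np.sum(y_true * np.log(y_prob)))       # -log(0.80) ≈ 0.223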
# Simple CNN with BatchNormalization and Dropout regularization
model = Sequential()
model.add(InputLayer(shape=(image_size, image_size, 3)))
model.add(Conv2D(16, kernel_size=(5, 5), activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.2))
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Conv2D(64, kernel_size=(3,3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.2))
model.add(Conv2D(128, kernel_size=(3,3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.2))
model.add(Conv2D(256, kernel_size=(3,3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.2))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(4, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
Model: "sequential_2"
Layer (type)                                Output Shape            Param #
conv2d_10 (Conv2D)                          (None, 220, 220, 16)    1,216
batch_normalization_4 (BatchNormalization)  (None, 220, 220, 16)    64
max_pooling2d_10 (MaxPooling2D)             (None, 110, 110, 16)    0
dropout_10 (Dropout)                        (None, 110, 110, 16)    0
conv2d_11 (Conv2D)                          (None, 108, 108, 32)    4,640
batch_normalization_5 (BatchNormalization)  (None, 108, 108, 32)    128
max_pooling2d_11 (MaxPooling2D)             (None, 54, 54, 32)      0
conv2d_12 (Conv2D)                          (None, 52, 52, 64)      18,496
max_pooling2d_12 (MaxPooling2D)             (None, 26, 26, 64)      0
dropout_11 (Dropout)                        (None, 26, 26, 64)      0
conv2d_13 (Conv2D)                          (None, 24, 24, 128)     73,856
max_pooling2d_13 (MaxPooling2D)             (None, 12, 12, 128)     0
dropout_12 (Dropout)                        (None, 12, 12, 128)     0
conv2d_14 (Conv2D)                          (None, 10, 10, 256)     295,168
max_pooling2d_14 (MaxPooling2D)             (None, 5, 5, 256)       0
dropout_13 (Dropout)                        (None, 5, 5, 256)       0
flatten_2 (Flatten)                         (None, 6400)            0
dense_4 (Dense)                             (None, 512)             3,277,312
dropout_14 (Dropout)                        (None, 512)             0
dense_5 (Dense)                             (None, 4)               2,052
Total params: 3,672,932 (14.01 MB)
Trainable params: 3,672,836 (14.01 MB)
Non-trainable params: 96 (384.00 B)
The near-random validation and test accuracy reported below, despite the high accuracies logged during training, shows the model performing poorly on every held-out split; this points to severe overfitting or a mismatch between training and evaluation rather than simple underfitting.
history = model.fit(X_train, y_train_new,
                    batch_size=64,
                    epochs=10,
                    steps_per_epoch=5,
                    validation_data=(X_valid, y_valid_new))
Epoch 1/10 5/5 ━━━━━━━━━━━━━━━━━━━━ 6s 1s/step - accuracy: 0.9951 - loss: 0.0164 - val_accuracy: 0.9389 - val_loss: 0.5294
Epoch 2/10 5/5 ━━━━━━━━━━━━━━━━━━━━ 6s 1s/step - accuracy: 0.9983 - loss: 0.0069 - val_accuracy: 0.9313 - val_loss: 0.5413
Epoch 3/10 5/5 ━━━━━━━━━━━━━━━━━━━━ 6s 1s/step - accuracy: 0.9803 - loss: 0.0544 - val_accuracy: 0.9351 - val_loss: 0.5765
Epoch 4/10 5/5 ━━━━━━━━━━━━━━━━━━━━ 6s 1s/step - accuracy: 0.9889 - loss: 0.0235 - val_accuracy: 0.9389 - val_loss: 0.6128
Epoch 5/10 5/5 ━━━━━━━━━━━━━━━━━━━━ 6s 1s/step - accuracy: 0.9902 - loss: 0.0434 - val_accuracy: 0.9389 - val_loss: 0.6036
Epoch 6/10 5/5 ━━━━━━━━━━━━━━━━━━━━ 6s 1s/step - accuracy: 0.9746 - loss: 0.1208 - val_accuracy: 0.9351 - val_loss: 0.4841
Epoch 7/10 5/5 ━━━━━━━━━━━━━━━━━━━━ 6s 1s/step - accuracy: 0.9867 - loss: 0.0426 - val_accuracy: 0.9275 - val_loss: 0.5177
Epoch 8/10 2/5 ━━━━━━━━━━━━━━━━━━━━ 2s 700ms/step - accuracy: 0.9752 - loss: 0.0742
C:\Users\yanch\anaconda3\Lib\contextlib.py:155: UserWarning: Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset.
5/5 ━━━━━━━━━━━━━━━━━━━━ 3s 405ms/step - accuracy: 0.9791 - loss: 0.0702 - val_accuracy: 0.9237 - val_loss: 0.4940
Epoch 9/10 5/5 ━━━━━━━━━━━━━━━━━━━━ 6s 1s/step - accuracy: 0.9889 - loss: 0.0358 - val_accuracy: 0.9084 - val_loss: 0.5404
Epoch 10/10 5/5 ━━━━━━━━━━━━━━━━━━━━ 6s 1s/step - accuracy: 0.9845 - loss: 0.0548 - val_accuracy: 0.9198 - val_loss: 0.5243
# Save the regularized CNN model (after the short 10-epoch run above)
model.save('new_cnn_model_1.keras')
# Plotting the training and validation loss
plt.figure(figsize=(10, 5))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('training_validation_loss.png')
plt.show()
# Predict on the validation set
y_pred = model.predict(X_valid)
y_pred = np.argmax(y_pred, axis=1)
# Calculate accuracy
accuracy = accuracy_score(Y_valid, y_pred)
print('Val Accuracy = %.4f' % accuracy)
9/9 ━━━━━━━━━━━━━━━━━━━━ 1s 124ms/step
Val Accuracy = 0.2672
# Predict on the test set
y_pred = model.predict(X_test)
y_pred = np.argmax(y_pred, axis=1)
# Calculate accuracy
accuracy = accuracy_score(Y_test, y_pred)
print('Test Accuracy = %.4f' % accuracy)
21/21 ━━━━━━━━━━━━━━━━━━━━ 3s 132ms/step
Test Accuracy = 0.2450
print("Classification Report:\n",classification_report(Y_test, y_pred))
Classification Report:
precision recall f1-score support
0 0.00 0.00 0.00 219
1 0.20 0.01 0.01 187
2 0.17 0.80 0.29 87
3 0.36 0.56 0.44 160
accuracy 0.25 653
macro avg 0.18 0.34 0.18 653
weighted avg 0.17 0.25 0.15 653
C:\Users\yanch\anaconda3\Lib\site-packages\sklearn\metrics\_classification.py:1469: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
Model Performance (for the longer training run below): the model demonstrates excellent learning capability and generalizes well to unseen data. The balance between training and validation performance suggests that the model configuration, including architecture adjustments, regularization techniques, and hyperparameters, is well tuned.
Stability and Overfitting: the relatively smooth, convergent training and validation loss curves indicate that the model is stable and not overfitting. This is corroborated by the validation loss closely tracking the training loss.
The training and validation loss curves show desirable behavior: training loss steadily decreases, indicating good learning progress, while validation loss decreases alongside and remains close to it, a good sign that the model is not overfitting.
history = model.fit(X_train, y_train_new,
                    batch_size=64,
                    epochs=50,
                    steps_per_epoch=50,
                    validation_data=(X_valid, y_valid_new))
Epoch 1/50 37/50 ━━━━━━━━━━━━━━━━━━━━ 12s 981ms/step - accuracy: 0.5946 - loss: 0.9527
C:\Users\yanch\anaconda3\Lib\contextlib.py:155: UserWarning: Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset.
50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 750ms/step - accuracy: 0.6032 - loss: 0.9389 - val_accuracy: 0.4008 - val_loss: 1.2256
Epoch 2/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 738ms/step - accuracy: 0.7068 - loss: 0.7582 - val_accuracy: 0.6145 - val_loss: 1.0696
Epoch 3/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 730ms/step - accuracy: 0.7281 - loss: 0.6811 - val_accuracy: 0.5496 - val_loss: 1.0466
Epoch 4/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 731ms/step - accuracy: 0.7868 - loss: 0.5578 - val_accuracy: 0.6374 - val_loss: 0.9908
Epoch 5/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 743ms/step - accuracy: 0.8076 - loss: 0.4926 - val_accuracy: 0.6718 - val_loss: 0.8159
Epoch 6/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 738ms/step - accuracy: 0.8473 - loss: 0.4301 - val_accuracy: 0.6908 - val_loss: 0.8416
Epoch 7/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 732ms/step - accuracy: 0.8699 - loss: 0.3598 - val_accuracy: 0.7443 - val_loss: 0.6991
Epoch 8/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 487s 10s/step - accuracy: 0.8710 - loss: 0.3516 - val_accuracy: 0.7595 - val_loss: 0.6429
Epoch 9/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 724ms/step - accuracy: 0.8879 - loss: 0.2956 - val_accuracy: 0.8130 - val_loss: 0.5012
Epoch 10/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 726ms/step - accuracy: 0.9031 - loss: 0.2607 - val_accuracy: 0.8588 - val_loss: 0.4286
Epoch 11/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 746ms/step - accuracy: 0.9337 - loss: 0.2054 - val_accuracy: 0.8702 - val_loss: 0.4034
Epoch 12/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 735ms/step - accuracy: 0.9091 - loss: 0.2255 - val_accuracy: 0.9008 - val_loss: 0.3470
Epoch 13/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 39s 780ms/step - accuracy: 0.9400 - loss: 0.1757 - val_accuracy: 0.8817 - val_loss: 0.4146
Epoch 14/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 753ms/step - accuracy: 0.9507 - loss: 0.1474 - val_accuracy: 0.8817 - val_loss: 0.4246
Epoch 15/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 754ms/step - accuracy: 0.9528 - loss: 0.1394 - val_accuracy: 0.9084 - val_loss: 0.3712
Epoch 16/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 747ms/step - accuracy: 0.9588 - loss: 0.1137 - val_accuracy: 0.9122 - val_loss: 0.3470
Epoch 17/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 732ms/step - accuracy: 0.9605 - loss: 0.1113 - val_accuracy: 0.9237 - val_loss: 0.3751
Epoch 18/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 734ms/step - accuracy: 0.9551 - loss: 0.1080 - val_accuracy: 0.8931 - val_loss: 0.3842
Epoch 19/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 748ms/step - accuracy: 0.9650 - loss: 0.0972 - val_accuracy: 0.9160 - val_loss: 0.3153
Epoch 20/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 37s 735ms/step - accuracy: 0.9733 - loss: 0.0816 - val_accuracy: 0.9237 - val_loss: 0.4023
Epoch 21/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 764ms/step - accuracy: 0.9799 - loss: 0.0745 - val_accuracy: 0.9313 - val_loss: 0.3908
Epoch 22/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 41s 817ms/step - accuracy: 0.9625 - loss: 0.0931 - val_accuracy: 0.9275 - val_loss: 0.3757
Epoch 23/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 41s 806ms/step - accuracy: 0.9834 - loss: 0.0574 - val_accuracy: 0.9313 - val_loss: 0.3387
Epoch 24/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 40s 793ms/step - accuracy: 0.9730 - loss: 0.0750 - val_accuracy: 0.9313 - val_loss: 0.2972
Epoch 25/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 40s 798ms/step - accuracy: 0.9832 - loss: 0.0579 - val_accuracy: 0.9160 - val_loss: 0.4324
Epoch 26/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 40s 764ms/step - accuracy: 0.9682 - loss: 0.1009 - val_accuracy: 0.9237 - val_loss: 0.4243
Epoch 27/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 756ms/step - accuracy: 0.9774 - loss: 0.0579 - val_accuracy: 0.9351 - val_loss: 0.3809
Epoch 28/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 39s 768ms/step - accuracy: 0.9821 - loss: 0.0566 - val_accuracy: 0.9427 - val_loss: 0.3873
Epoch 29/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 759ms/step - accuracy: 0.9894 - loss: 0.0359 - val_accuracy: 0.9427 - val_loss: 0.4556
Epoch 30/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 757ms/step - accuracy: 0.9863 - loss: 0.0442 - val_accuracy: 0.9160 - val_loss: 0.4866
Epoch 31/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 39s 760ms/step - accuracy: 0.9858 - loss: 0.0473 - val_accuracy: 0.9313 - val_loss: 0.4852
Epoch 32/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 39s 783ms/step - accuracy: 0.9872 - loss: 0.0445 - val_accuracy: 0.9389 - val_loss: 0.4071
Epoch 33/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 41s 810ms/step - accuracy: 0.9916 - loss: 0.0268 - val_accuracy: 0.9198 - val_loss: 0.5157
Epoch 34/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 761ms/step - accuracy: 0.9872 - loss: 0.0383 - val_accuracy: 0.9275 - val_loss: 0.4700
Epoch 35/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 39s 765ms/step - accuracy: 0.9849 - loss: 0.0430 - val_accuracy: 0.9160 - val_loss: 0.4828
Epoch 36/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 755ms/step - accuracy: 0.9837 - loss: 0.0496 - val_accuracy: 0.9275 - val_loss: 0.5223
Epoch 37/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 756ms/step - accuracy: 0.9831 - loss: 0.0462 - val_accuracy: 0.9389 - val_loss: 0.4216
Epoch 38/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 757ms/step - accuracy: 0.9848 - loss: 0.0453 - val_accuracy: 0.9427 - val_loss: 0.4807
Epoch 39/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 749ms/step - accuracy: 0.9832 - loss: 0.0521 - val_accuracy: 0.9198 - val_loss: 0.6065
Epoch 40/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 39s 771ms/step - accuracy: 0.9840 - loss: 0.0441 - val_accuracy: 0.9389 - val_loss: 0.5364
Epoch 41/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 42s 829ms/step - accuracy: 0.9882 - loss: 0.0404 - val_accuracy: 0.9313 - val_loss: 0.4418
Epoch 42/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 39s 777ms/step - accuracy: 0.9881 - loss: 0.0338 - val_accuracy: 0.9427 - val_loss: 0.4785
Epoch 43/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 760ms/step - accuracy: 0.9925 - loss: 0.0230 - val_accuracy: 0.9580 - val_loss: 0.4233
Epoch 44/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 39s 780ms/step - accuracy: 0.9899 - loss: 0.0393 - val_accuracy: 0.9427 - val_loss: 0.4645
Epoch 45/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 39s 764ms/step - accuracy: 0.9915 - loss: 0.0212 - val_accuracy: 0.9389 - val_loss: 0.5971
Epoch 46/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 761ms/step - accuracy: 0.9898 - loss: 0.0299 - val_accuracy: 0.9504 - val_loss: 0.5630
Epoch 47/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 39s 778ms/step - accuracy: 0.9895 - loss: 0.0299 - val_accuracy: 0.9427 - val_loss: 0.4841
Epoch 48/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 759ms/step - accuracy: 0.9892 - loss: 0.0384 - val_accuracy: 0.9466 - val_loss: 0.4732
Epoch 49/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 40s 786ms/step - accuracy: 0.9900 - loss: 0.0299 - val_accuracy: 0.9466 - val_loss: 0.4952
Epoch 50/50 50/50 ━━━━━━━━━━━━━━━━━━━━ 38s 763ms/step - accuracy: 0.9898 - loss: 0.0312 - val_accuracy: 0.9427 - val_loss: 0.5609
# Save the regularized CNN model (after the 50-epoch run above)
model.save('new_cnn_model_2.keras')
# Plotting the training and validation loss
plt.figure(figsize=(10, 5))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('training_validation_loss.png')
plt.show()
# Predict on the validation set
y_pred = model.predict(X_valid)
y_pred = np.argmax(y_pred, axis=1)
# Calculate accuracy
accuracy = accuracy_score(Y_valid, y_pred)
print('Val Accuracy = %.4f' % accuracy)
9/9 ━━━━━━━━━━━━━━━━━━━━ 1s 101ms/step
Val Accuracy = 0.9427
# Predict on the test set
y_pred = model.predict(X_test)
y_pred = np.argmax(y_pred, axis=1)
# Calculate accuracy
accuracy = accuracy_score(Y_test, y_pred)
print('Test Accuracy = %.4f' % accuracy)
21/21 ━━━━━━━━━━━━━━━━━━━━ 2s 110ms/step
Test Accuracy = 0.9173
print("Classification Report:\n",classification_report(Y_test, y_pred))
Classification Report:
precision recall f1-score support
0 0.88 0.94 0.91 219
1 0.95 0.82 0.88 187
2 0.92 0.92 0.92 87
3 0.94 0.99 0.96 160
accuracy 0.92 653
macro avg 0.92 0.92 0.92 653
weighted avg 0.92 0.92 0.92 653
Loss Stability Concerns: Despite excellent performance metrics, the variability in the validation loss could still be a concern. It may suggest that the model could start overfitting if trained for more epochs without adjustments.
history = model.fit(X_train, y_train_new,
                    batch_size=64,
                    epochs=50,
                    steps_per_epoch=100,
                    validation_data=(X_valid, y_valid_new))
Epoch 1/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 41s 394ms/step - accuracy: 0.9861 - loss: 0.0356 - val_accuracy: 0.9313 - val_loss: 0.5516
Epoch 2/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 53s 515ms/step - accuracy: 0.9854 - loss: 0.0395 - val_accuracy: 0.9427 - val_loss: 0.4838
Epoch 3/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 49s 479ms/step - accuracy: 0.9863 - loss: 0.0466 - val_accuracy: 0.9389 - val_loss: 0.5167
Epoch 4/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 49s 481ms/step - accuracy: 0.9904 - loss: 0.0240 - val_accuracy: 0.9504 - val_loss: 0.4565
Epoch 5/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 48s 469ms/step - accuracy: 0.9896 - loss: 0.0271 - val_accuracy: 0.9351 - val_loss: 0.5200
Epoch 6/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 52s 515ms/step - accuracy: 0.9911 - loss: 0.0332 - val_accuracy: 0.9427 - val_loss: 0.4462
Epoch 7/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 43s 427ms/step - accuracy: 0.9931 - loss: 0.0218 - val_accuracy: 0.9427 - val_loss: 0.4205
Epoch 8/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 40s 394ms/step - accuracy: 0.9918 - loss: 0.0275 - val_accuracy: 0.9313 - val_loss: 0.4520
Epoch 9/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 45s 448ms/step - accuracy: 0.9854 - loss: 0.0446 - val_accuracy: 0.9237 - val_loss: 0.5280
Epoch 10/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 43s 419ms/step - accuracy: 0.9859 - loss: 0.0413 - val_accuracy: 0.9313 - val_loss: 0.5361
Epoch 11/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 40s 396ms/step - accuracy: 0.9908 - loss: 0.0263 - val_accuracy: 0.9466 - val_loss: 0.5306
Epoch 12/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 44s 432ms/step - accuracy: 0.9872 - loss: 0.0373 - val_accuracy: 0.9389 - val_loss: 0.4706
Epoch 13/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 43s 426ms/step - accuracy: 0.9905 - loss: 0.0328 - val_accuracy: 0.9351 - val_loss: 0.4630
Epoch 14/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 41s 406ms/step - accuracy: 0.9916 - loss: 0.0253 - val_accuracy: 0.9313 - val_loss: 0.5127
Epoch 15/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 43s 414ms/step - accuracy: 0.9857 - loss: 0.0457 - val_accuracy: 0.9160 - val_loss: 0.5916
Epoch 16/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 43s 422ms/step - accuracy: 0.9826 - loss: 0.0566 - val_accuracy: 0.9389 - val_loss: 0.4681
Epoch 17/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 42s 418ms/step - accuracy: 0.9882 - loss: 0.0367 - val_accuracy: 0.9237 - val_loss: 0.4785
Epoch 18/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 43s 407ms/step - accuracy: 0.9899 - loss: 0.0242 - val_accuracy: 0.9046 - val_loss: 0.7800
Epoch 19/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 51s 498ms/step - accuracy: 0.9911 - loss: 0.0288 - val_accuracy: 0.9427 - val_loss: 0.5952
Epoch 20/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 56s 549ms/step - accuracy: 0.9956 - loss: 0.0171 - val_accuracy: 0.9351 - val_loss: 0.6594
Epoch 21/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 53s 522ms/step - accuracy: 0.9913 - loss: 0.0264 - val_accuracy: 0.9427 - val_loss: 0.6845
Epoch 22/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 56s 555ms/step - accuracy: 0.9953 - loss: 0.0183 - val_accuracy: 0.9427 - val_loss: 0.5893
Epoch 23/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 55s 535ms/step - accuracy: 0.9968 - loss: 0.0114 - val_accuracy: 0.9427 - val_loss: 0.6818
Epoch 24/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 52s 512ms/step - accuracy: 0.9953 - loss: 0.0230 - val_accuracy: 0.9389 - val_loss: 0.6214
Epoch 25/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 54s 530ms/step - accuracy: 0.9949 - loss: 0.0188 - val_accuracy: 0.9351 - val_loss: 0.6380
Epoch 26/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 52s 510ms/step - accuracy: 0.9944 - loss: 0.0229 - val_accuracy: 0.9389 - val_loss: 0.7436
Epoch 27/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 53s 517ms/step - accuracy: 0.9942 - loss: 0.0166 - val_accuracy: 0.9389 - val_loss: 0.6644
Epoch 28/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 53s 526ms/step - accuracy: 0.9872 - loss: 0.0349 - val_accuracy: 0.9275 - val_loss: 0.7268
Epoch 29/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 53s 516ms/step - accuracy: 0.9936 - loss: 0.0243 - val_accuracy: 0.9466 - val_loss: 0.6768
Epoch 30/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 52s 515ms/step - accuracy: 0.9888 - loss: 0.0291 - val_accuracy: 0.9313 - val_loss: 0.5553
Epoch 31/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 52s 513ms/step - accuracy: 0.9836 - loss: 0.0569 - val_accuracy: 0.9313 - val_loss: 0.6402
Epoch 32/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 51s 503ms/step - accuracy: 0.9914 - loss: 0.0280 - val_accuracy: 0.9504 - val_loss: 0.5774
Epoch 33/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 41s 399ms/step - accuracy: 0.9852 - loss: 0.0441 - val_accuracy: 0.9389 - val_loss: 0.4621
Epoch 34/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 42s 412ms/step - accuracy: 0.9868 - loss: 0.0392 - val_accuracy: 0.9466 - val_loss: 0.4100
Epoch 35/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 54s 534ms/step - accuracy: 0.9945 - loss: 0.0211 - val_accuracy: 0.9427 - val_loss: 0.4357
Epoch 36/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 60s 598ms/step - accuracy: 0.9942 - loss: 0.0218 - val_accuracy: 0.9427 - val_loss: 0.5096
Epoch 37/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 48s 474ms/step - accuracy: 0.9911 - loss: 0.0231 - val_accuracy: 0.9504 - val_loss: 0.4844
Epoch 38/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 56s 552ms/step - accuracy: 0.9934 - loss: 0.0187 - val_accuracy: 0.9389 - val_loss: 0.5852
Epoch 39/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 43s 412ms/step - accuracy: 0.9952 - loss: 0.0111 - val_accuracy: 0.9466 - val_loss: 0.6731
Epoch 40/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 43s 419ms/step - accuracy: 0.9937 - loss: 0.0232 - val_accuracy: 0.9351 - val_loss: 0.7436
Epoch 41/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 41s 406ms/step - accuracy: 0.9859 - loss: 0.0451 - val_accuracy: 0.9580 - val_loss: 0.4898
Epoch 42/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 52s 518ms/step - accuracy: 0.9888 - loss: 0.0285 - val_accuracy: 0.9351 - val_loss: 0.6509
Epoch 43/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 51s 496ms/step - accuracy: 0.9893 - loss: 0.0317 - val_accuracy: 0.9313 - val_loss: 0.4865
Epoch 44/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 50s 490ms/step - accuracy: 0.9880 - loss: 0.0427 - val_accuracy: 0.9466 - val_loss: 0.5625
Epoch 45/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 51s 506ms/step - accuracy: 0.9873 - loss: 0.0441 - val_accuracy: 0.9466 - val_loss: 0.4856
Epoch 46/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 53s 520ms/step - accuracy: 0.9909 - loss: 0.0250 - val_accuracy: 0.9275 - val_loss: 0.6252
Epoch 47/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 52s 510ms/step - accuracy: 0.9846 - loss: 0.0482 - val_accuracy: 0.9351 - val_loss: 0.5449
Epoch 48/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 50s 491ms/step - accuracy: 0.9914 - loss: 0.0350 - val_accuracy: 0.9313 - val_loss: 0.5463
Epoch 49/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 49s 485ms/step - accuracy: 0.9886 - loss: 0.0433 - val_accuracy: 0.9427 - val_loss: 0.6445
Epoch 50/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 46s 446ms/step - accuracy: 0.9814 - loss: 0.0612 - val_accuracy: 0.9237 - val_loss: 0.5988
# Save the model
model.save('new_cnn_model_3.keras')
# Plotting the training and validation loss
plt.figure(figsize=(10, 5))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('training_validation_loss.png')
plt.show()
# Predict on the validation set
y_pred = model.predict(X_valid)
y_pred = np.argmax(y_pred, axis=1)
# Calculate accuracy
accuracy = accuracy_score(Y_valid, y_pred)
print('Val Accuracy = %.4f' % accuracy)
9/9 ━━━━━━━━━━━━━━━━━━━━ 1s 99ms/step
Val Accuracy = 0.9237
# Predict on the test set
y_pred = model.predict(X_test)
y_pred = np.argmax(y_pred, axis=1)
# Calculate accuracy
accuracy = accuracy_score(Y_test, y_pred)
print('Test Accuracy = %.4f' % accuracy)
21/21 ━━━━━━━━━━━━━━━━━━━━ 2s 109ms/step
Test Accuracy = 0.9449
print("Classification Report:\n",classification_report(Y_test, y_pred))
Classification Report:
precision recall f1-score support
0 0.97 0.92 0.94 219
1 0.95 0.93 0.94 187
2 0.88 0.94 0.91 87
3 0.95 1.00 0.97 160
accuracy 0.94 653
macro avg 0.94 0.95 0.94 653
weighted avg 0.95 0.94 0.94 653
history = model.fit(X_train, y_train_new,
                    batch_size=64,
                    epochs=35,
                    steps_per_epoch=100,
                    validation_data=(X_valid, y_valid_new))
Epoch 1/35 37/100 ━━━━━━━━━━━━━━━━━━━━ 1:02 991ms/step - accuracy: 0.3305 - loss: 4.8506
C:\Users\yanch\anaconda3\Lib\contextlib.py:155: UserWarning: Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset.
100/100 ━━━━━━━━━━━━━━━━━━━━ 43s 375ms/step - accuracy: 0.3932 - loss: 3.3018 - val_accuracy: 0.4466 - val_loss: 1.3470
Epoch 2/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 37s 368ms/step - accuracy: 0.5852 - loss: 0.9946 - val_accuracy: 0.3130 - val_loss: 1.3530
Epoch 3/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 371ms/step - accuracy: 0.6493 - loss: 0.8459 - val_accuracy: 0.3053 - val_loss: 1.4348
Epoch 4/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 369ms/step - accuracy: 0.6894 - loss: 0.7776 - val_accuracy: 0.3244 - val_loss: 1.5073
Epoch 5/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 373ms/step - accuracy: 0.7075 - loss: 0.7224 - val_accuracy: 0.3550 - val_loss: 1.4179
Epoch 6/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 378ms/step - accuracy: 0.7537 - loss: 0.6409 - val_accuracy: 0.3588 - val_loss: 1.4829
Epoch 7/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 368ms/step - accuracy: 0.7665 - loss: 0.5970 - val_accuracy: 0.4504 - val_loss: 1.2290
Epoch 8/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 374ms/step - accuracy: 0.8033 - loss: 0.5161 - val_accuracy: 0.4924 - val_loss: 1.2251
Epoch 9/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 371ms/step - accuracy: 0.8198 - loss: 0.4714 - val_accuracy: 0.5534 - val_loss: 1.1228
Epoch 10/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 372ms/step - accuracy: 0.8283 - loss: 0.4445 - val_accuracy: 0.7099 - val_loss: 0.7793
Epoch 11/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 45s 444ms/step - accuracy: 0.8432 - loss: 0.4166 - val_accuracy: 0.7519 - val_loss: 0.6910
Epoch 12/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 43s 425ms/step - accuracy: 0.8540 - loss: 0.3607 - val_accuracy: 0.7366 - val_loss: 0.6631
Epoch 13/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 44s 432ms/step - accuracy: 0.8700 - loss: 0.3242 - val_accuracy: 0.7290 - val_loss: 0.7293
Epoch 14/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 42s 410ms/step - accuracy: 0.8770 - loss: 0.3202 - val_accuracy: 0.8282 - val_loss: 0.4950
Epoch 15/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 47s 461ms/step - accuracy: 0.9050 - loss: 0.2347 - val_accuracy: 0.7710 - val_loss: 0.7423
Epoch 16/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 42s 411ms/step - accuracy: 0.9131 - loss: 0.2261 - val_accuracy: 0.8473 - val_loss: 0.4228
Epoch 17/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 48s 474ms/step - accuracy: 0.9374 - loss: 0.1804 - val_accuracy: 0.8855 - val_loss: 0.3896
Epoch 18/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 46s 449ms/step - accuracy: 0.9257 - loss: 0.1934 - val_accuracy: 0.8893 - val_loss: 0.3702
Epoch 19/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 40s 394ms/step - accuracy: 0.9385 - loss: 0.1609 - val_accuracy: 0.9046 - val_loss: 0.3787
Epoch 20/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 42s 409ms/step - accuracy: 0.9522 - loss: 0.1312 - val_accuracy: 0.8969 - val_loss: 0.4051
Epoch 21/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 42s 412ms/step - accuracy: 0.9461 - loss: 0.1377 - val_accuracy: 0.8435 - val_loss: 0.4281
Epoch 22/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 45s 441ms/step - accuracy: 0.9588 - loss: 0.1241 - val_accuracy: 0.8969 - val_loss: 0.3753
Epoch 23/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 40s 391ms/step - accuracy: 0.9568 - loss: 0.1206 - val_accuracy: 0.9122 - val_loss: 0.3589
Epoch 24/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 39s 383ms/step - accuracy: 0.9659 - loss: 0.0923 - val_accuracy: 0.9122 - val_loss: 0.3198
Epoch 25/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 41s 401ms/step - accuracy: 0.9632 - loss: 0.1027 - val_accuracy: 0.8969 - val_loss: 0.3980
Epoch 26/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 372ms/step - accuracy: 0.9676 - loss: 0.0879 - val_accuracy: 0.8969 - val_loss: 0.4138
Epoch 27/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 372ms/step - accuracy: 0.9734 - loss: 0.0846 - val_accuracy: 0.8893 - val_loss: 0.4401
Epoch 28/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 39s 381ms/step - accuracy: 0.9737 - loss: 0.0844 - val_accuracy: 0.9198 - val_loss: 0.3196
Epoch 29/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 40s 370ms/step - accuracy: 0.9690 - loss: 0.0896 - val_accuracy: 0.8740 - val_loss: 0.4665
Epoch 30/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 376ms/step - accuracy: 0.9777 - loss: 0.0713 - val_accuracy: 0.9160 - val_loss: 0.3465
Epoch 31/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 378ms/step - accuracy: 0.9777 - loss: 0.0761 - val_accuracy: 0.9084 - val_loss: 0.3995
Epoch 32/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 39s 382ms/step - accuracy: 0.9735 - loss: 0.0870 - val_accuracy: 0.9160 - val_loss: 0.3470
Epoch 33/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 374ms/step - accuracy: 0.9787 - loss: 0.0633 - val_accuracy: 0.9198 - val_loss: 0.3858
Epoch 34/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 38s 377ms/step - accuracy: 0.9821 - loss: 0.0524 - val_accuracy: 0.9237 - val_loss: 0.3363
Epoch 35/35 100/100 ━━━━━━━━━━━━━━━━━━━━ 39s 381ms/step - accuracy: 0.9829 - loss: 0.0523 - val_accuracy: 0.9237 - val_loss: 0.3193
# Save the model
model.save('new_cnn_model_6.keras')
# Plotting the training and validation loss
plt.figure(figsize=(10, 5))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('training_validation_loss.png')
plt.show()
# Predict on the validation set
y_pred = model.predict(X_valid)
y_pred = np.argmax(y_pred, axis=1)
# Calculate accuracy
accuracy = accuracy_score(Y_valid, y_pred)
print('Val Accuracy = %.4f' % accuracy)
9/9 ━━━━━━━━━━━━━━━━━━━━ 1s 127ms/step
Val Accuracy = 0.9237
# Predict on the test set
y_pred = model.predict(X_test)
y_pred = np.argmax(y_pred, axis=1)
# Calculate accuracy
accuracy = accuracy_score(Y_test, y_pred)
print('Test Accuracy = %.4f' % accuracy)
21/21 ━━━━━━━━━━━━━━━━━━━━ 3s 125ms/step
Test Accuracy = 0.9280
print("Classification Report:\n",classification_report(Y_test, y_pred))
Classification Report:
precision recall f1-score support
0 0.90 0.93 0.92 219
1 0.95 0.86 0.90 187
2 0.87 0.94 0.91 87
3 0.98 1.00 0.99 160
accuracy 0.93 653
macro avg 0.92 0.93 0.93 653
weighted avg 0.93 0.93 0.93 653
Validation Accuracy: the model achieves a validation accuracy of 92.37%.
Test Accuracy: the test accuracy is slightly higher at 92.80%.
Precision and Recall: all classes show strong precision (87-98%) and recall (86-100%). This indicates that the model not only identifies positive cases correctly but is also precise in its predictions, minimizing false positives.
F1-Score: high F1-scores across all classes (90-99%) suggest a balanced trade-off between precision and recall, which is crucial for reliable classification.
Accuracy Curve: the training accuracy plateaus close to 100%, while the validation accuracy stabilizes at a high level with some gap from training, indicating slight overfitting but still within an acceptable range.
Loss Curve: training loss decreases sharply and flattens, which is ideal. The validation loss also decreases but shows more fluctuation, which is typical, though it should be monitored to ensure it does not start to diverge significantly from the training loss.
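The cells above only plot the loss curves; a minimal sketch (assuming the default metric names recorded in history.history) that would plot the accuracy curves discussed here:
plt.figure(figsize=(10, 5))
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()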
I've implemented useful callbacks like EarlyStopping and ReduceLROnPlateau, which are beneficial for handling overfitting and optimizing the training process:
EarlyStopping is configured to monitor the training loss, stopping the training if there are no improvements beyond a minimal delta, indicating that continuing training is inefficient.
ReduceLROnPlateau reduces the learning rate when the validation loss stops improving, helping the model to fine-tune adjustments in weights and potentially escape local minima.
The EarlyStopping and ReduceLROnPlateau are both callbacks in Keras that serve as training interventions to improve the training process and prevent overfitting. Each of these has specific roles and is used to monitor different aspects of the model during training. Let’s delve into the goals and functionalities of each:
EarlyStopping Goal: To halt the training process early if there is no significant improvement in a specified metric over a defined number of epochs. This is particularly useful in avoiding overfitting and unnecessarily long training times.
ReduceLROnPlateau Goal: To reduce the learning rate when a metric has stopped improving. This helps the model to fine-tune and potentially escape local minima during training. Lowering the learning rate can allow the model to make smaller changes to the weights and potentially discover better minima.
# Stop training if loss doesn't keep decreasing.
model_es = EarlyStopping(monitor='loss', min_delta=1e-9, patience=12, verbose=True)
model_rlr = ReduceLROnPlateau(monitor='val_loss', factor=0.3, patience=6, verbose=True)
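ModelCheckpoint was imported alongside the other callbacks but is not used below; a minimal sketch (hypothetical filename) that would additionally keep the best weights by validation loss:
# Hypothetical addition: save the best-so-far model whenever val_loss improves
model_ckpt = ModelCheckpoint('best_cnn_model.keras', monitor='val_loss',
                             save_best_only=True, verbose=True)
It could then be passed along with the others as callbacks=[model_es, model_rlr, model_ckpt].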
history = model.fit(X_train, y_train_new,
                    batch_size=64,
                    epochs=100,
                    validation_data=(X_valid, y_valid_new),
                    callbacks=[model_es, model_rlr])
Training log (abridged): validation accuracy climbed from 0.1641 at epoch 1 to ~0.92 by epoch 19 (best val_loss 0.3436 at epoch 17). ReduceLROnPlateau cut the learning rate from 1e-3 to 3e-4 at epoch 23, to 9e-5 at epoch 29, to 2.7e-5 at epoch 35, and to 8.1e-6 at epoch 41 as val_loss plateaued around 0.34-0.45. EarlyStopping halted the run at epoch 46/100 with training accuracy 0.9941 and val_accuracy 0.9275 (val_loss 0.4445).
# Save the baseline model (data augmentation with rotation_range = 20)
model.save('new_cnn_model1.keras')
# Predict on the validation set
y_pred = model.predict(X_valid)
y_pred = np.argmax(y_pred, axis=1)
# Calculate accuracy
accuracy = accuracy_score(Y_valid, y_pred)
print('Val Accuracy = %.4f' % accuracy)
9/9 ━━━━━━━━━━━━━━━━━━━━ 1s 125ms/step Val Accuracy = 0.9275
# Predict on the test set
y_pred = model.predict(X_test)
y_pred = np.argmax(y_pred, axis=1)
# Calculate accuracy
accuracy = accuracy_score(Y_test, y_pred)
print('Test Accuracy = %.4f' % accuracy)
21/21 ━━━━━━━━━━━━━━━━━━━━ 2s 106ms/step Test Accuracy = 0.9096
print("Classification Report:\n",classification_report(Y_test, y_pred))
Classification Report:
precision recall f1-score support
0 0.85 0.93 0.89 198
1 0.91 0.80 0.85 183
2 0.96 0.90 0.93 104
3 0.95 1.00 0.97 168
accuracy 0.91 653
macro avg 0.92 0.91 0.91 653
weighted avg 0.91 0.91 0.91 653
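For readability, sklearn can print the class names instead of integer indices; a minimal variant, assuming the `labels` list defined earlier:
# Same report with class names in place of the indices 0-3
print(classification_report(Y_test, y_pred, target_names=labels))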
Analysis:
Strong classes: The model is most reliable on pituitary tumors (precision 0.95, recall 1.00) and no_tumor scans (precision 0.96), and it identifies gliomas with high recall (0.93).
Challenges with meningioma: Meningioma has the lowest recall (0.80), and its errors appear to land mostly on glioma, suggesting the two classes share visual features the model struggles to separate; this confusion warrants further investigation.
Potential for serious misclassification: Misclassification between tumorous and non-tumorous scans is rare here, but it is the most clinically serious error and should be driven as close to zero as possible.
labels = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']
# Define the custom color map
custom_colors = ['#01411C', '#4B6F44', '#4F7942', '#74C365', '#D0F0C0']
custom_cmap = matplotlib.colors.ListedColormap(custom_colors)
# Compute the confusion matrix (stored as `cm` so the imported
# confusion_matrix function is not shadowed if the cell is re-run)
cm = confusion_matrix(Y_test, y_pred)
# Create a display object with the class labels
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
# Plot the confusion matrix with the custom color map
fig, ax = plt.subplots()
disp.plot(cmap=custom_cmap, ax=ax)
# Set the figure title
fig.text(s='Heatmap of the Confusion Matrix', size=18, fontweight='bold',
         fontname='monospace', color=colors_dark[1], y=0.92, x=0.10, alpha=0.8)
# Rotate x-axis labels
plt.xticks(rotation=45)
# Save the figure
plt.savefig('CM CNN-2.png', dpi=300, bbox_inches='tight')
# Show the plot
plt.show()
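To put numbers on the meningioma/glioma overlap flagged in the analysis, the off-diagonal cells of the confusion matrix can be listed directly; a small sketch, assuming the `cm` and `labels` objects from the cell above:
# Print every off-diagonal cell: true class -> predicted class, count, and share
for true_idx in range(cm.shape[0]):
    for pred_idx in range(cm.shape[1]):
        if true_idx != pred_idx and cm[true_idx, pred_idx] > 0:
            share = cm[true_idx, pred_idx] / cm[true_idx].sum()
            print(f"{labels[true_idx]} -> {labels[pred_idx]}: "
                  f"{cm[true_idx, pred_idx]} scans ({share:.1%})")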
fig, ax = plt.subplots(1, 2, figsize=(10, 5), facecolor='white')
# Plot training and validation accuracy
ax[0].plot(history.history['accuracy'])
ax[0].plot(history.history['val_accuracy'])
ax[0].set_title('Model Accuracy')
ax[0].set_xlabel('Epoch')
ax[0].set_ylabel('Accuracy')
ax[0].legend(['Train', 'Validation'], loc='upper left')
# Plot training and validation loss
ax[1].plot(history.history['loss'])
ax[1].plot(history.history['val_loss'])
ax[1].set_title('Model Loss')
ax[1].set_xlabel('Epoch')
ax[1].set_ylabel('Loss')
ax[1].legend(['Train', 'Validation'], loc='upper right')
# Arrange the panels before saving
plt.tight_layout()
# Save the figure
plt.savefig('plot CNN-2.png', dpi=300, bbox_inches='tight')
plt.show()
from sklearn.metrics import roc_curve, roc_auc_score
import numpy as np
# Compute predicted probabilities for each class
y_probs = model.predict(X_test)
# Ensure that the target labels Y_test are one-hot encoded (2-dimensional)
if len(Y_test.shape) == 1:
    Y_test = np.eye(len(np.unique(Y_test)))[Y_test.astype(int)]
# Compute the ROC curve and AUC score for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(Y_test.shape[1]):
    fpr[i], tpr[i], _ = roc_curve(Y_test[:, i], y_probs[:, i])
    roc_auc[i] = roc_auc_score(Y_test[:, i], y_probs[:, i])
# Plot the ROC curve for each class
plt.figure()
for i in range(Y_test.shape[1]):
    plt.plot(fpr[i], tpr[i], label=f'Class {i} (AUC = {roc_auc[i]:.2f})')
# Set the title and axis labels
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(loc='lower right')
# Save the figure
plt.savefig('ROC CNN-2.png', dpi=300, bbox_inches='tight')
# Show the plot
plt.show()
21/21 ━━━━━━━━━━━━━━━━━━━━ 2s 107ms/step
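A single summary figure can complement the per-class curves. Since `Y_test` is one-hot at this point and `y_probs` holds the softmax outputs, a macro-average AUC takes one extra line (an optional addition, not part of the original run):
# Macro-average one-vs-rest AUC over the four classes
macro_auc = roc_auc_score(Y_test, y_probs, average='macro')
print(f'Macro-average AUC = {macro_auc:.4f}')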
class_labels = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']
plt.figure(figsize=(16, 20))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow(X_test[i])
    actual_label_idx = np.argmax(Y_test[i])  # Y_test is one-hot after the ROC step
    predicted_label_idx = y_pred[i]  # y_pred already holds class indices (argmax applied above)
    plt.title(f"Actual label: {class_labels[actual_label_idx]}\nPredicted label: {class_labels[predicted_label_idx]}")
    plt.axis("off")
plt.show()
import os
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import resize
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
num_classes = len(labels)
# Define EfficientNet model
base_model = EfficientNetB0(weights='imagenet', include_top=False, input_shape=(image_size, image_size, 3))
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(512, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)
# Compile the model
model.compile(optimizer=Adam(learning_rate=0.001), loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
Model: "functional_29"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ │ input_layer_4 │ (None, 224, 224, │ 0 │ - │ │ (InputLayer) │ 3) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ rescaling_4 │ (None, 224, 224, │ 0 │ input_layer_4[0]… │ │ (Rescaling) │ 3) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ normalization_2 │ (None, 224, 224, │ 7 │ rescaling_4[0][0] │ │ (Normalization) │ 3) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ rescaling_5 │ (None, 224, 224, │ 0 │ normalization_2[… │ │ (Rescaling) │ 3) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ stem_conv_pad │ (None, 225, 225, │ 0 │ rescaling_5[0][0] │ │ (ZeroPadding2D) │ 3) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ stem_conv (Conv2D) │ (None, 112, 112, │ 864 │ stem_conv_pad[0]… │ │ │ 32) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ stem_bn │ (None, 112, 112, │ 128 │ stem_conv[0][0] │ │ (BatchNormalizatio… │ 32) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ stem_activation │ (None, 112, 112, │ 0 │ stem_bn[0][0] │ │ (Activation) │ 32) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_dwconv │ (None, 112, 112, │ 288 │ stem_activation[… │ │ (DepthwiseConv2D) │ 32) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_bn │ (None, 112, 112, │ 128 │ block1a_dwconv[0… │ │ (BatchNormalizatio… │ 32) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_activation │ (None, 112, 112, │ 0 │ block1a_bn[0][0] │ │ (Activation) │ 32) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_se_squeeze │ (None, 32) │ 0 │ block1a_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_se_reshape │ (None, 1, 1, 32) │ 0 │ block1a_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_se_reduce │ (None, 1, 1, 8) │ 264 │ block1a_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_se_expand │ (None, 1, 1, 32) │ 288 │ block1a_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_se_excite │ (None, 112, 112, │ 0 │ block1a_activati… │ │ (Multiply) │ 32) │ │ block1a_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_project_co… │ (None, 112, 112, │ 512 │ block1a_se_excit… │ │ (Conv2D) │ 16) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_project_bn │ (None, 112, 112, │ 64 │ block1a_project_… │ │ (BatchNormalizatio… │ 16) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_expand_conv │ (None, 112, 112, │ 1,536 │ block1a_project_… │ │ (Conv2D) │ 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_expand_bn │ (None, 112, 112, │ 384 │ block2a_expand_c… │ │ (BatchNormalizatio… │ 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ 
│ block2a_expand_act… │ (None, 112, 112, │ 0 │ block2a_expand_b… │ │ (Activation) │ 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_dwconv_pad │ (None, 113, 113, │ 0 │ block2a_expand_a… │ │ (ZeroPadding2D) │ 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_dwconv │ (None, 56, 56, │ 864 │ block2a_dwconv_p… │ │ (DepthwiseConv2D) │ 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_bn │ (None, 56, 56, │ 384 │ block2a_dwconv[0… │ │ (BatchNormalizatio… │ 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_activation │ (None, 56, 56, │ 0 │ block2a_bn[0][0] │ │ (Activation) │ 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_se_squeeze │ (None, 96) │ 0 │ block2a_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_se_reshape │ (None, 1, 1, 96) │ 0 │ block2a_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_se_reduce │ (None, 1, 1, 4) │ 388 │ block2a_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_se_expand │ (None, 1, 1, 96) │ 480 │ block2a_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_se_excite │ (None, 56, 56, │ 0 │ block2a_activati… │ │ (Multiply) │ 96) │ │ block2a_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_project_co… │ (None, 56, 56, │ 2,304 │ block2a_se_excit… │ │ (Conv2D) │ 24) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_project_bn │ (None, 56, 56, │ 96 │ block2a_project_… │ │ (BatchNormalizatio… │ 24) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_expand_conv │ (None, 56, 56, │ 3,456 │ block2a_project_… │ │ (Conv2D) │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_expand_bn │ (None, 56, 56, │ 576 │ block2b_expand_c… │ │ (BatchNormalizatio… │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_expand_act… │ (None, 56, 56, │ 0 │ block2b_expand_b… │ │ (Activation) │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_dwconv │ (None, 56, 56, │ 1,296 │ block2b_expand_a… │ │ (DepthwiseConv2D) │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_bn │ (None, 56, 56, │ 576 │ block2b_dwconv[0… │ │ (BatchNormalizatio… │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_activation │ (None, 56, 56, │ 0 │ block2b_bn[0][0] │ │ (Activation) │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_se_squeeze │ (None, 144) │ 0 │ block2b_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_se_reshape │ (None, 1, 1, 144) │ 0 │ block2b_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_se_reduce │ (None, 1, 1, 6) │ 870 │ block2b_se_resha… │ │ (Conv2D) │ │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_se_expand │ (None, 1, 1, 144) │ 1,008 │ block2b_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_se_excite │ (None, 56, 56, │ 0 │ block2b_activati… │ │ (Multiply) │ 144) │ │ block2b_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_project_co… │ (None, 56, 56, │ 3,456 │ block2b_se_excit… │ │ (Conv2D) │ 24) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_project_bn │ (None, 56, 56, │ 96 │ block2b_project_… │ │ (BatchNormalizatio… │ 24) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_drop │ (None, 56, 56, │ 0 │ block2b_project_… │ │ (Dropout) │ 24) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_add (Add) │ (None, 56, 56, │ 0 │ block2b_drop[0][… │ │ │ 24) │ │ block2a_project_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_expand_conv │ (None, 56, 56, │ 3,456 │ block2b_add[0][0] │ │ (Conv2D) │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_expand_bn │ (None, 56, 56, │ 576 │ block3a_expand_c… │ │ (BatchNormalizatio… │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_expand_act… │ (None, 56, 56, │ 0 │ block3a_expand_b… │ │ (Activation) │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_dwconv_pad │ (None, 59, 59, │ 0 │ block3a_expand_a… │ │ (ZeroPadding2D) │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_dwconv │ (None, 28, 28, │ 3,600 │ block3a_dwconv_p… │ │ (DepthwiseConv2D) │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_bn │ (None, 28, 28, │ 576 │ block3a_dwconv[0… │ │ (BatchNormalizatio… │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_activation │ (None, 28, 28, │ 0 │ block3a_bn[0][0] │ │ (Activation) │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_se_squeeze │ (None, 144) │ 0 │ block3a_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_se_reshape │ (None, 1, 1, 144) │ 0 │ block3a_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_se_reduce │ (None, 1, 1, 6) │ 870 │ block3a_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_se_expand │ (None, 1, 1, 144) │ 1,008 │ block3a_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_se_excite │ (None, 28, 28, │ 0 │ block3a_activati… │ │ (Multiply) │ 144) │ │ block3a_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_project_co… │ (None, 28, 28, │ 5,760 │ block3a_se_excit… │ │ (Conv2D) │ 40) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_project_bn │ (None, 28, 28, │ 160 │ block3a_project_… │ │ (BatchNormalizatio… │ 40) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_expand_conv │ (None, 28, 28, │ 9,600 │ 
block3a_project_… │ │ (Conv2D) │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_expand_bn │ (None, 28, 28, │ 960 │ block3b_expand_c… │ │ (BatchNormalizatio… │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_expand_act… │ (None, 28, 28, │ 0 │ block3b_expand_b… │ │ (Activation) │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_dwconv │ (None, 28, 28, │ 6,000 │ block3b_expand_a… │ │ (DepthwiseConv2D) │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_bn │ (None, 28, 28, │ 960 │ block3b_dwconv[0… │ │ (BatchNormalizatio… │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_activation │ (None, 28, 28, │ 0 │ block3b_bn[0][0] │ │ (Activation) │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_se_squeeze │ (None, 240) │ 0 │ block3b_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_se_reshape │ (None, 1, 1, 240) │ 0 │ block3b_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_se_reduce │ (None, 1, 1, 10) │ 2,410 │ block3b_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_se_expand │ (None, 1, 1, 240) │ 2,640 │ block3b_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_se_excite │ (None, 28, 28, │ 0 │ block3b_activati… │ │ (Multiply) │ 240) │ │ block3b_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_project_co… │ (None, 28, 28, │ 9,600 │ block3b_se_excit… │ │ (Conv2D) │ 40) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_project_bn │ (None, 28, 28, │ 160 │ block3b_project_… │ │ (BatchNormalizatio… │ 40) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_drop │ (None, 28, 28, │ 0 │ block3b_project_… │ │ (Dropout) │ 40) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_add (Add) │ (None, 28, 28, │ 0 │ block3b_drop[0][… │ │ │ 40) │ │ block3a_project_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_expand_conv │ (None, 28, 28, │ 9,600 │ block3b_add[0][0] │ │ (Conv2D) │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_expand_bn │ (None, 28, 28, │ 960 │ block4a_expand_c… │ │ (BatchNormalizatio… │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_expand_act… │ (None, 28, 28, │ 0 │ block4a_expand_b… │ │ (Activation) │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_dwconv_pad │ (None, 29, 29, │ 0 │ block4a_expand_a… │ │ (ZeroPadding2D) │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_dwconv │ (None, 14, 14, │ 2,160 │ block4a_dwconv_p… │ │ (DepthwiseConv2D) │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_bn │ (None, 14, 14, │ 960 │ block4a_dwconv[0… │ │ (BatchNormalizatio… │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ 
block4a_activation │ (None, 14, 14, │ 0 │ block4a_bn[0][0] │ │ (Activation) │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_se_squeeze │ (None, 240) │ 0 │ block4a_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_se_reshape │ (None, 1, 1, 240) │ 0 │ block4a_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_se_reduce │ (None, 1, 1, 10) │ 2,410 │ block4a_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_se_expand │ (None, 1, 1, 240) │ 2,640 │ block4a_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_se_excite │ (None, 14, 14, │ 0 │ block4a_activati… │ │ (Multiply) │ 240) │ │ block4a_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_project_co… │ (None, 14, 14, │ 19,200 │ block4a_se_excit… │ │ (Conv2D) │ 80) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_project_bn │ (None, 14, 14, │ 320 │ block4a_project_… │ │ (BatchNormalizatio… │ 80) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_expand_conv │ (None, 14, 14, │ 38,400 │ block4a_project_… │ │ (Conv2D) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_expand_bn │ (None, 14, 14, │ 1,920 │ block4b_expand_c… │ │ (BatchNormalizatio… │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_expand_act… │ (None, 14, 14, │ 0 │ block4b_expand_b… │ │ (Activation) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_dwconv │ (None, 14, 14, │ 4,320 │ block4b_expand_a… │ │ (DepthwiseConv2D) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_bn │ (None, 14, 14, │ 1,920 │ block4b_dwconv[0… │ │ (BatchNormalizatio… │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_activation │ (None, 14, 14, │ 0 │ block4b_bn[0][0] │ │ (Activation) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_se_squeeze │ (None, 480) │ 0 │ block4b_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_se_reshape │ (None, 1, 1, 480) │ 0 │ block4b_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_se_reduce │ (None, 1, 1, 20) │ 9,620 │ block4b_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_se_expand │ (None, 1, 1, 480) │ 10,080 │ block4b_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_se_excite │ (None, 14, 14, │ 0 │ block4b_activati… │ │ (Multiply) │ 480) │ │ block4b_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_project_co… │ (None, 14, 14, │ 38,400 │ block4b_se_excit… │ │ (Conv2D) │ 80) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_project_bn │ (None, 14, 14, │ 320 │ block4b_project_… │ │ (BatchNormalizatio… │ 80) │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_drop │ (None, 14, 14, │ 0 │ block4b_project_… │ │ (Dropout) │ 80) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_add (Add) │ (None, 14, 14, │ 0 │ block4b_drop[0][… │ │ │ 80) │ │ block4a_project_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_expand_conv │ (None, 14, 14, │ 38,400 │ block4b_add[0][0] │ │ (Conv2D) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_expand_bn │ (None, 14, 14, │ 1,920 │ block4c_expand_c… │ │ (BatchNormalizatio… │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_expand_act… │ (None, 14, 14, │ 0 │ block4c_expand_b… │ │ (Activation) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_dwconv │ (None, 14, 14, │ 4,320 │ block4c_expand_a… │ │ (DepthwiseConv2D) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_bn │ (None, 14, 14, │ 1,920 │ block4c_dwconv[0… │ │ (BatchNormalizatio… │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_activation │ (None, 14, 14, │ 0 │ block4c_bn[0][0] │ │ (Activation) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_se_squeeze │ (None, 480) │ 0 │ block4c_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_se_reshape │ (None, 1, 1, 480) │ 0 │ block4c_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_se_reduce │ (None, 1, 1, 20) │ 9,620 │ block4c_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_se_expand │ (None, 1, 1, 480) │ 10,080 │ block4c_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_se_excite │ (None, 14, 14, │ 0 │ block4c_activati… │ │ (Multiply) │ 480) │ │ block4c_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_project_co… │ (None, 14, 14, │ 38,400 │ block4c_se_excit… │ │ (Conv2D) │ 80) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_project_bn │ (None, 14, 14, │ 320 │ block4c_project_… │ │ (BatchNormalizatio… │ 80) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_drop │ (None, 14, 14, │ 0 │ block4c_project_… │ │ (Dropout) │ 80) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_add (Add) │ (None, 14, 14, │ 0 │ block4c_drop[0][… │ │ │ 80) │ │ block4b_add[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_expand_conv │ (None, 14, 14, │ 38,400 │ block4c_add[0][0] │ │ (Conv2D) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_expand_bn │ (None, 14, 14, │ 1,920 │ block5a_expand_c… │ │ (BatchNormalizatio… │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_expand_act… │ (None, 14, 14, │ 0 │ block5a_expand_b… │ │ (Activation) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_dwconv │ (None, 14, 14, │ 12,000 │ block5a_expand_a… │ │ 
(DepthwiseConv2D) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_bn │ (None, 14, 14, │ 1,920 │ block5a_dwconv[0… │ │ (BatchNormalizatio… │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_activation │ (None, 14, 14, │ 0 │ block5a_bn[0][0] │ │ (Activation) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_se_squeeze │ (None, 480) │ 0 │ block5a_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_se_reshape │ (None, 1, 1, 480) │ 0 │ block5a_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_se_reduce │ (None, 1, 1, 20) │ 9,620 │ block5a_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_se_expand │ (None, 1, 1, 480) │ 10,080 │ block5a_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_se_excite │ (None, 14, 14, │ 0 │ block5a_activati… │ │ (Multiply) │ 480) │ │ block5a_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_project_co… │ (None, 14, 14, │ 53,760 │ block5a_se_excit… │ │ (Conv2D) │ 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_project_bn │ (None, 14, 14, │ 448 │ block5a_project_… │ │ (BatchNormalizatio… │ 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_expand_conv │ (None, 14, 14, │ 75,264 │ block5a_project_… │ │ (Conv2D) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_expand_bn │ (None, 14, 14, │ 2,688 │ block5b_expand_c… │ │ (BatchNormalizatio… │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_expand_act… │ (None, 14, 14, │ 0 │ block5b_expand_b… │ │ (Activation) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_dwconv │ (None, 14, 14, │ 16,800 │ block5b_expand_a… │ │ (DepthwiseConv2D) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_bn │ (None, 14, 14, │ 2,688 │ block5b_dwconv[0… │ │ (BatchNormalizatio… │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_activation │ (None, 14, 14, │ 0 │ block5b_bn[0][0] │ │ (Activation) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_se_squeeze │ (None, 672) │ 0 │ block5b_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_se_reshape │ (None, 1, 1, 672) │ 0 │ block5b_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_se_reduce │ (None, 1, 1, 28) │ 18,844 │ block5b_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_se_expand │ (None, 1, 1, 672) │ 19,488 │ block5b_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_se_excite │ (None, 14, 14, │ 0 │ block5b_activati… │ │ (Multiply) │ 672) │ │ block5b_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_project_co… │ 
(None, 14, 14, │ 75,264 │ block5b_se_excit… │ │ (Conv2D) │ 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_project_bn │ (None, 14, 14, │ 448 │ block5b_project_… │ │ (BatchNormalizatio… │ 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_drop │ (None, 14, 14, │ 0 │ block5b_project_… │ │ (Dropout) │ 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_add (Add) │ (None, 14, 14, │ 0 │ block5b_drop[0][… │ │ │ 112) │ │ block5a_project_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_expand_conv │ (None, 14, 14, │ 75,264 │ block5b_add[0][0] │ │ (Conv2D) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_expand_bn │ (None, 14, 14, │ 2,688 │ block5c_expand_c… │ │ (BatchNormalizatio… │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_expand_act… │ (None, 14, 14, │ 0 │ block5c_expand_b… │ │ (Activation) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_dwconv │ (None, 14, 14, │ 16,800 │ block5c_expand_a… │ │ (DepthwiseConv2D) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_bn │ (None, 14, 14, │ 2,688 │ block5c_dwconv[0… │ │ (BatchNormalizatio… │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_activation │ (None, 14, 14, │ 0 │ block5c_bn[0][0] │ │ (Activation) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_se_squeeze │ (None, 672) │ 0 │ block5c_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_se_reshape │ (None, 1, 1, 672) │ 0 │ block5c_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_se_reduce │ (None, 1, 1, 28) │ 18,844 │ block5c_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_se_expand │ (None, 1, 1, 672) │ 19,488 │ block5c_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_se_excite │ (None, 14, 14, │ 0 │ block5c_activati… │ │ (Multiply) │ 672) │ │ block5c_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_project_co… │ (None, 14, 14, │ 75,264 │ block5c_se_excit… │ │ (Conv2D) │ 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_project_bn │ (None, 14, 14, │ 448 │ block5c_project_… │ │ (BatchNormalizatio… │ 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_drop │ (None, 14, 14, │ 0 │ block5c_project_… │ │ (Dropout) │ 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_add (Add) │ (None, 14, 14, │ 0 │ block5c_drop[0][… │ │ │ 112) │ │ block5b_add[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_expand_conv │ (None, 14, 14, │ 75,264 │ block5c_add[0][0] │ │ (Conv2D) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_expand_bn │ (None, 14, 14, │ 2,688 │ block6a_expand_c… │ │ (BatchNormalizatio… │ 672) │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_expand_act… │ (None, 14, 14, │ 0 │ block6a_expand_b… │ │ (Activation) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_dwconv_pad │ (None, 17, 17, │ 0 │ block6a_expand_a… │ │ (ZeroPadding2D) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_dwconv │ (None, 7, 7, 672) │ 16,800 │ block6a_dwconv_p… │ │ (DepthwiseConv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_bn │ (None, 7, 7, 672) │ 2,688 │ block6a_dwconv[0… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_activation │ (None, 7, 7, 672) │ 0 │ block6a_bn[0][0] │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_se_squeeze │ (None, 672) │ 0 │ block6a_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_se_reshape │ (None, 1, 1, 672) │ 0 │ block6a_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_se_reduce │ (None, 1, 1, 28) │ 18,844 │ block6a_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_se_expand │ (None, 1, 1, 672) │ 19,488 │ block6a_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_se_excite │ (None, 7, 7, 672) │ 0 │ block6a_activati… │ │ (Multiply) │ │ │ block6a_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_project_co… │ (None, 7, 7, 192) │ 129,024 │ block6a_se_excit… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_project_bn │ (None, 7, 7, 192) │ 768 │ block6a_project_… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_expand_conv │ (None, 7, 7, │ 221,184 │ block6a_project_… │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_expand_bn │ (None, 7, 7, │ 4,608 │ block6b_expand_c… │ │ (BatchNormalizatio… │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_expand_act… │ (None, 7, 7, │ 0 │ block6b_expand_b… │ │ (Activation) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_dwconv │ (None, 7, 7, │ 28,800 │ block6b_expand_a… │ │ (DepthwiseConv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_bn │ (None, 7, 7, │ 4,608 │ block6b_dwconv[0… │ │ (BatchNormalizatio… │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_activation │ (None, 7, 7, │ 0 │ block6b_bn[0][0] │ │ (Activation) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_se_squeeze │ (None, 1152) │ 0 │ block6b_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_se_reshape │ (None, 1, 1, │ 0 │ block6b_se_squee… │ │ (Reshape) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_se_reduce │ (None, 1, 1, 48) │ 55,344 │ block6b_se_resha… │ 
│ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_se_expand │ (None, 1, 1, │ 56,448 │ block6b_se_reduc… │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_se_excite │ (None, 7, 7, │ 0 │ block6b_activati… │ │ (Multiply) │ 1152) │ │ block6b_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_project_co… │ (None, 7, 7, 192) │ 221,184 │ block6b_se_excit… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_project_bn │ (None, 7, 7, 192) │ 768 │ block6b_project_… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_drop │ (None, 7, 7, 192) │ 0 │ block6b_project_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_add (Add) │ (None, 7, 7, 192) │ 0 │ block6b_drop[0][… │ │ │ │ │ block6a_project_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_expand_conv │ (None, 7, 7, │ 221,184 │ block6b_add[0][0] │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_expand_bn │ (None, 7, 7, │ 4,608 │ block6c_expand_c… │ │ (BatchNormalizatio… │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_expand_act… │ (None, 7, 7, │ 0 │ block6c_expand_b… │ │ (Activation) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_dwconv │ (None, 7, 7, │ 28,800 │ block6c_expand_a… │ │ (DepthwiseConv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_bn │ (None, 7, 7, │ 4,608 │ block6c_dwconv[0… │ │ (BatchNormalizatio… │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_activation │ (None, 7, 7, │ 0 │ block6c_bn[0][0] │ │ (Activation) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_se_squeeze │ (None, 1152) │ 0 │ block6c_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_se_reshape │ (None, 1, 1, │ 0 │ block6c_se_squee… │ │ (Reshape) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_se_reduce │ (None, 1, 1, 48) │ 55,344 │ block6c_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_se_expand │ (None, 1, 1, │ 56,448 │ block6c_se_reduc… │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_se_excite │ (None, 7, 7, │ 0 │ block6c_activati… │ │ (Multiply) │ 1152) │ │ block6c_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_project_co… │ (None, 7, 7, 192) │ 221,184 │ block6c_se_excit… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_project_bn │ (None, 7, 7, 192) │ 768 │ block6c_project_… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_drop │ (None, 7, 7, 192) │ 0 │ block6c_project_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_add (Add) │ (None, 7, 7, 192) │ 0 │ 
block6c_drop[0][… │ │ │ │ │ block6b_add[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_expand_conv │ (None, 7, 7, │ 221,184 │ block6c_add[0][0] │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_expand_bn │ (None, 7, 7, │ 4,608 │ block6d_expand_c… │ │ (BatchNormalizatio… │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_expand_act… │ (None, 7, 7, │ 0 │ block6d_expand_b… │ │ (Activation) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_dwconv │ (None, 7, 7, │ 28,800 │ block6d_expand_a… │ │ (DepthwiseConv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_bn │ (None, 7, 7, │ 4,608 │ block6d_dwconv[0… │ │ (BatchNormalizatio… │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_activation │ (None, 7, 7, │ 0 │ block6d_bn[0][0] │ │ (Activation) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_se_squeeze │ (None, 1152) │ 0 │ block6d_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_se_reshape │ (None, 1, 1, │ 0 │ block6d_se_squee… │ │ (Reshape) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_se_reduce │ (None, 1, 1, 48) │ 55,344 │ block6d_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_se_expand │ (None, 1, 1, │ 56,448 │ block6d_se_reduc… │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_se_excite │ (None, 7, 7, │ 0 │ block6d_activati… │ │ (Multiply) │ 1152) │ │ block6d_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_project_co… │ (None, 7, 7, 192) │ 221,184 │ block6d_se_excit… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_project_bn │ (None, 7, 7, 192) │ 768 │ block6d_project_… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_drop │ (None, 7, 7, 192) │ 0 │ block6d_project_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_add (Add) │ (None, 7, 7, 192) │ 0 │ block6d_drop[0][… │ │ │ │ │ block6c_add[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_expand_conv │ (None, 7, 7, │ 221,184 │ block6d_add[0][0] │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_expand_bn │ (None, 7, 7, │ 4,608 │ block7a_expand_c… │ │ (BatchNormalizatio… │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_expand_act… │ (None, 7, 7, │ 0 │ block7a_expand_b… │ │ (Activation) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_dwconv │ (None, 7, 7, │ 10,368 │ block7a_expand_a… │ │ (DepthwiseConv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_bn │ (None, 7, 7, │ 4,608 │ block7a_dwconv[0… │ │ (BatchNormalizatio… │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ 
block7a_activation │ (None, 7, 7, │ 0 │ block7a_bn[0][0] │ │ (Activation) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_se_squeeze │ (None, 1152) │ 0 │ block7a_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_se_reshape │ (None, 1, 1, │ 0 │ block7a_se_squee… │ │ (Reshape) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_se_reduce │ (None, 1, 1, 48) │ 55,344 │ block7a_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_se_expand │ (None, 1, 1, │ 56,448 │ block7a_se_reduc… │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_se_excite │ (None, 7, 7, │ 0 │ block7a_activati… │ │ (Multiply) │ 1152) │ │ block7a_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_project_co… │ (None, 7, 7, 320) │ 368,640 │ block7a_se_excit… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_project_bn │ (None, 7, 7, 320) │ 1,280 │ block7a_project_… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ top_conv (Conv2D) │ (None, 7, 7, │ 409,600 │ block7a_project_… │ │ │ 1280) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ top_bn │ (None, 7, 7, │ 5,120 │ top_conv[0][0] │ │ (BatchNormalizatio… │ 1280) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ top_activation │ (None, 7, 7, │ 0 │ top_bn[0][0] │ │ (Activation) │ 1280) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ global_average_poo… │ (None, 1280) │ 0 │ top_activation[0… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_7 (Dense) │ (None, 512) │ 655,872 │ global_average_p… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_8 (Dense) │ (None, 4) │ 2,052 │ dense_7[0][0] │ └─────────────────────┴───────────────────┴────────────┴───────────────────┘
Total params: 4,707,495 (17.96 MB)
Trainable params: 4,665,472 (17.80 MB)
Non-trainable params: 42,023 (164.16 KB)
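The parameter counts show that almost the entire 4.7M-parameter ImageNet backbone is trainable from the first epoch. A common alternative with a dataset of this size is to freeze the backbone and train only the new head first; a minimal sketch of that variant (not what the run below does):
# Freeze every EfficientNetB0 layer so only the new classification head trains
for layer in base_model.layers:
    layer.trainable = False
model.compile(optimizer=Adam(learning_rate=0.001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])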
# Train the model
history = model.fit(train_generator,
epochs=50,
validation_data=val_generator,
steps_per_epoch=len(X_train) // 64,
validation_steps=len(X_valid) // 64)
Epoch 1/50: 36/36 ━━ 158s 3s/step - accuracy: 0.6149 - loss: 0.8603 - val_accuracy: 0.1016 - val_loss: 8.2727
Two warnings appeared during training. Keras flagged that the `PyDataset` class should call `super().__init__(**kwargs)` in its constructor, and it repeatedly warned "Your input ran out of data; interrupting training" because `steps_per_epoch * epochs` requests more batches than the generator can produce. As a result, every third epoch (3, 6, 9, ...) reports accuracy 0.0000 and loss 0.0000, having processed no batches at all. Over the remaining epochs, training accuracy climbed to ~0.98-0.99 while validation accuracy fluctuated wildly between 0.09 and 0.91, ending epoch 50 at accuracy 0.9843, loss 0.0448, val_accuracy 0.7377, val_loss 1.2556. (Full 50-epoch log abridged.)
# Evaluate the model
test_generator = valid_datagen.flow(X_test, y_test_new, batch_size=64)
test_loss, test_accuracy = model.evaluate(test_generator, steps=len(X_test) // 64)
print(f"Test Accuracy: {test_accuracy}")
10/10 ━━━━━━━━━━━━━━━━━━━━ 2s 156ms/step - accuracy: 0.1679 - loss: 1.3863 Test Accuracy: 0.16249999403953552
Training accuracy exceeds 0.98 while test accuracy is only about 0.16, and validation accuracy swings wildly from epoch to epoch: clear signs of severe overfitting and an unstable training setup. The epochs that report accuracy and loss of exactly 0.0000e+00 suggest the training generator is exhausted partway through those epochs, i.e., steps_per_epoch requests more batches than the generator yields (a sketch of one possible fix follows the plots below).
# Plotting training and validation loss
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()
# Plotting training and validation accuracy
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
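A minimal sketch of one possible fix, assuming the zero-metric epochs come from steps_per_epoch over-counting the batches the generator actually yields: derive the step counts from the data and the batch size of 64 used above, rather than hard-coding them.
import math
batch_size = 64  # must match the batch_size passed to datagen.flow(...)
# ImageDataGenerator.flow produces ceil(len(x) / batch_size) batches per
# pass over the data; requesting more steps than that can exhaust the
# generator mid-epoch, which shows up as accuracy/loss of exactly 0.0.
steps_per_epoch = math.ceil(len(X_train) / batch_size)
validation_steps = math.ceil(len(X_valid) / batch_size)
# Equivalent, letting the iterators report their own length:
# steps_per_epoch = len(train_generator)
# validation_steps = len(val_generator)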
from tensorflow.keras.regularizers import l2  # L2 penalty for the dense layers
model = Sequential([
    Conv2D(16, (5, 5), activation='relu', input_shape=(image_size, image_size, 3)),
    MaxPooling2D((2, 2)),
    Conv2D(32, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(256, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Flatten(),
    Dense(512, activation='relu', kernel_regularizer=l2(0.01)),
    Dense(len(labels), activation='softmax', kernel_regularizer=l2(0.01)),
])
model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
Model: "sequential_2"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ conv2d_10 (Conv2D)              │ (None, 220, 220, 16)   │         1,216 │
│ max_pooling2d_10 (MaxPooling2D) │ (None, 110, 110, 16)   │             0 │
│ conv2d_11 (Conv2D)              │ (None, 108, 108, 32)   │         4,640 │
│ max_pooling2d_11 (MaxPooling2D) │ (None, 54, 54, 32)     │             0 │
│ conv2d_12 (Conv2D)              │ (None, 52, 52, 64)     │        18,496 │
│ max_pooling2d_12 (MaxPooling2D) │ (None, 26, 26, 64)     │             0 │
│ conv2d_13 (Conv2D)              │ (None, 24, 24, 128)    │        73,856 │
│ max_pooling2d_13 (MaxPooling2D) │ (None, 12, 12, 128)    │             0 │
│ conv2d_14 (Conv2D)              │ (None, 10, 10, 256)    │       295,168 │
│ max_pooling2d_14 (MaxPooling2D) │ (None, 5, 5, 256)      │             0 │
│ flatten_2 (Flatten)             │ (None, 6400)           │             0 │
│ dense_9 (Dense)                 │ (None, 512)            │     3,277,312 │
│ dense_10 (Dense)                │ (None, 4)              │         2,052 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 3,672,740 (14.01 MB)
Trainable params: 3,672,740 (14.01 MB)
Non-trainable params: 0 (0.00 B)
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.00001)
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
# Train the model
history = model.fit(train_generator,
                    epochs=50,
                    validation_data=val_generator,
                    steps_per_epoch=len(X_train) // 64,
                    validation_steps=len(X_valid) // 64,
                    callbacks=[reduce_lr, early_stopping])
[Training log, condensed. Training accuracy rose from 0.2520 (epoch 1) to about 0.71; validation accuracy fluctuated between 0.33 and 0.66. ReduceLROnPlateau cut the learning rate from 1e-4 to 2e-5 at epoch 26 and to the min_lr of 1e-5 at epoch 37; EarlyStopping ended training after epoch 41, restoring the best weights (val_loss 0.9106, epoch 31). The same accuracy/loss = 0.0000e+00 pattern recurs every third epoch.]
# Evaluate the model
test_generator = valid_datagen.flow(X_test, y_test_new, batch_size=64)
test_loss, test_accuracy = model.evaluate(test_generator, steps=len(X_test) // 64)
print(f"Test Accuracy: {test_accuracy}")
10/10 ━━━━━━━━━━━━━━━━━━━━ 2s 161ms/step - accuracy: 0.2872 - loss: 2.4204 Test Accuracy: 0.3031249940395355
The loss and accuracy curves show that the L2-regularized model trains far more slowly and never stabilizes: training accuracy plateaus around 0.71, validation accuracy never exceeds 0.66, and test accuracy is only about 0.30. The heavy L2 penalty (0.01) and the recurring generator-exhaustion epochs both plausibly contribute (one mitigation sketch follows the plots below).
# Plotting training and validation loss
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()
# Plotting training and validation accuracy
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
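One hedged option for taming the instability, not used in the runs above: clip the gradient norm in the optimizer so a single bad batch cannot produce an outsized update (a lighter L2 penalty, e.g. 1e-4 instead of 0.01, would be a complementary change).
from tensorflow.keras.optimizers import Adam
# Hypothetical variant: cap the global gradient norm at 1.0 so no single
# batch can trigger a huge parameter update, then recompile as before.
clipped_adam = Adam(learning_rate=1e-4, clipnorm=1.0)
model.compile(optimizer=clipped_adam,
              loss='categorical_crossentropy',
              metrics=['accuracy'])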
Across all runs, CNN Architecture 2 (L2 regularization) Run 3 achieved the highest test accuracy, 0.9449; the table below summarizes the runs, and a checkpointing sketch follows it.
import pandas as pd
# Define data for Architecture 2 (L2 regularization)
data2 = {
'Steps per Epoch': [5, 50, 100, 100, 37],
'Epochs': [10, 50, 50, 35, 100],
'Accuracy': [0.245, 0.9173, 0.9449, 0.928, 0.9096],
'Notes': ['', '', '', '', 'early stopped at epoch = 46']
}
# Define data for Architecture 1
data1 = {
'Steps per Epoch': [100],
'Epochs': [35],
'Accuracy': [0.925],
'Notes': ['']
}
# Create DataFrames
df1 = pd.DataFrame(data1, index=[f"Architecture 1 - Run {i+1}" for i in range(len(data1['Steps per Epoch']))])
df2 = pd.DataFrame(data2, index=[f"Architecture 2 (L2 Regularization) - Run {i+1}" for i in range(len(data2['Steps per Epoch']))])
# Concatenate both DataFrames
df = pd.concat([df1, df2])
df
| | Steps per Epoch | Epochs | Accuracy | Notes |
|---|---|---|---|---|
| Architecture 1 - Run 1 | 100 | 35 | 0.9250 | |
| Architecture 2 (L2 Regularization) - Run 1 | 5 | 10 | 0.2450 | |
| Architecture 2 (L2 Regularization) - Run 2 | 50 | 50 | 0.9173 | |
| Architecture 2 (L2 Regularization) - Run 3 | 100 | 50 | 0.9449 | |
| Architecture 2 (L2 Regularization) - Run 4 | 100 | 35 | 0.9280 | |
| Architecture 2 (L2 Regularization) - Run 5 | 37 | 100 | 0.9096 | early stopped at epoch = 46 |
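Because the best run is easy to lose across retrains, here is a hedged sketch (the filename is illustrative) of persisting the best weights with the ModelCheckpoint callback imported earlier:
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import load_model
# Save the full model whenever val_accuracy improves (illustrative path)
checkpoint = ModelCheckpoint('best_brain_tumor_cnn.keras',
                             monitor='val_accuracy',
                             save_best_only=True,
                             verbose=1)
# Pass it alongside the other callbacks in model.fit(...):
#   callbacks=[reduce_lr, early_stopping, checkpoint]
# Restore the best run later without retraining:
# best_model = load_model('best_brain_tumor_cnn.keras')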
### Model Performance Analysis
Training and Validation Loss: the training loss decreases steadily and then flattens (indicating good learning), while the validation loss fluctuates but broadly tracks the training loss without diverging sharply. This suggests the best model is not overfitting significantly.
Validation Accuracy: peaked at approximately 95.04% during training. Test Accuracy: 94.49%, close behind; this agreement between validation and test accuracy is a good sign that the model generalizes well.
Precision and Recall: very high across all classes, with Class 3 (pituitary_tumor) achieving perfect recall (1.00), meaning the model identifies every true positive for that class with no false negatives.
F1-Score: also high across all classes, indicating a good balance between precision and recall. The weighted averages for accuracy, precision, recall, and F1-score are all above 0.94, which is excellent (a sketch for reproducing these metrics follows).
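For reference, a minimal sketch of how the per-class figures above can be reproduced with the sklearn imports loaded earlier, assuming X_test is the preprocessed test array and y_test_new its one-hot labels, as in the evaluation cells:
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
# Predicted class = argmax over the softmax outputs
y_pred = np.argmax(model.predict(X_test), axis=1)
y_true = np.argmax(y_test_new, axis=1)  # undo the one-hot encoding
print(classification_report(y_true, y_pred, target_names=labels))
print(confusion_matrix(y_true, y_pred))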
### Observations
Model Stability: the best model performs consistently across metrics, indicating robust learning.
Loss Fluctuations: the fluctuations in validation loss could indicate minor overfitting, or may simply reflect the optimizer traversing a complex loss landscape; since the curves do not diverge significantly, this is not a major concern at present.
Augmentation for Underrepresented Classes: Increase the number of augmented images for the underrepresented class (no_tumor) to balance the dataset.
Class Weights: utilize class weights during training so that underrepresented classes carry more weight in the loss calculation (see the sketch after this list).
Oversampling/Undersampling: Consider oversampling the minority class or undersampling the majority classes.
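A hedged sketch of the class-weights option, deriving balanced weights from the integer training labels (Y_train as built during data loading):
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
# 'balanced' weights are inversely proportional to class frequency
weights = compute_class_weight(class_weight='balanced',
                               classes=np.unique(Y_train),
                               y=Y_train)
class_weight_dict = dict(enumerate(weights))
# model.fit(train_generator, ..., class_weight=class_weight_dict)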
The current augmentation strategy is robust, but less aggressive transformations may suit brain images better, since orientation and structure carry diagnostic information; for example, a large rotation range may be inappropriate because tumors and their surrounding structures are highly orientation-specific.
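A hedged sketch of a milder augmentation pipeline (all parameter values are illustrative, not tuned):
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Small rotations and shifts; horizontal flip only, since axial brain
# slices are roughly left-right symmetric, while vertical flips would
# produce anatomically implausible images.
gentle_datagen = ImageDataGenerator(
    rotation_range=10,       # instead of an aggressive 45-90 degrees
    width_shift_range=0.05,
    height_shift_range=0.05,
    zoom_range=0.05,
    horizontal_flip=True,
    vertical_flip=False,
)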
Depth and Complexity: because medical images are complex, consider gradually deepening the CNN, adding convolutional blocks to capture higher-level features.
Advanced Architectures: explore more sophisticated architectures such as ResNet, Inception, or DenseNet, whose deeper structures are often more effective for medical image analysis (a transfer-learning sketch follows).
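As a hedged starting point (hyperparameters are illustrative), an ImageNet-pretrained ResNet50 can be attached to the same four-class head:
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
# Pretrained convolutional base without its ImageNet classifier
base = ResNet50(weights='imagenet', include_top=False,
                input_shape=(224, 224, 3))
base.trainable = False  # freeze pretrained features for the first phase
x = GlobalAveragePooling2D()(base.output)
x = Dense(256, activation='relu')(x)
outputs = Dense(len(labels), activation='softmax')(x)
tl_model = Model(inputs=base.input, outputs=outputs)
tl_model.compile(optimizer=Adam(learning_rate=1e-4),
                 loss='categorical_crossentropy',
                 metrics=['accuracy'])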
Once deployed, continuously monitor the model’s performance and establish a feedback loop with medical professionals to collect insights and further improve the model.
Before full deployment, ensure that the model undergoes thorough clinical validation to meet regulatory standards and to confirm that it performs well across different demographics and equipment variations.